mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-05 14:11:29 +00:00
This commit is contained in:
@@ -0,0 +1,410 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem, typename Policy>
|
||||
struct GemmPipelineAgBgCrImplBase
|
||||
{
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayout>>;
|
||||
using BInDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataType>>;
|
||||
|
||||
template <typename T>
|
||||
using has_bcastpolicy_type = decltype(T::BCastPolicy);
|
||||
|
||||
static constexpr bool IsBCastPolicyBeforeLDSWrite = [] {
|
||||
if constexpr(is_detected<has_bcastpolicy_type, Problem>{})
|
||||
{
|
||||
return Problem::BCastPolicy == CastPolicy::BeforeLDSWrite;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}();
|
||||
|
||||
using BDataType = std::conditional_t<IsBCastPolicyBeforeLDSWrite, ADataType, BInDataType>;
|
||||
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
#if defined(__gfx950__)
|
||||
// The combination of pk_int4_t and transposed loading causes compilation errors.
|
||||
// Therefore do not use transposed loading in this case.
|
||||
// Also, transpose load (ds_read_tr) requires specific tile distribution patterns
|
||||
// that only work for certain K warp tile sizes based on data type size:
|
||||
// - For 1-byte types (fp8/bf8): K warp tile <= 64
|
||||
// - For 2-byte types (fp16/bf16): K warp tile <= 32
|
||||
static constexpr bool is_a_load_tr = []() {
|
||||
using WarpTile = typename BlockGemmShape::WarpTile;
|
||||
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
|
||||
constexpr index_t kMaxKWarpTile = (sizeof(ADataType) == 1) ? 64 : 32;
|
||||
if constexpr(std::is_same_v<ADataType, float>)
|
||||
return false;
|
||||
else if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
return false;
|
||||
else if constexpr(kKWarpTile > kMaxKWarpTile)
|
||||
return false;
|
||||
else
|
||||
return std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
}();
|
||||
|
||||
static constexpr bool is_b_load_tr = []() {
|
||||
using WarpTile = typename BlockGemmShape::WarpTile;
|
||||
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
|
||||
constexpr index_t kMaxKWarpTile = (sizeof(BDataType) == 1) ? 64 : 32;
|
||||
if constexpr(std::is_same_v<BDataType, float>)
|
||||
return false;
|
||||
else if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
return false;
|
||||
else if constexpr(kKWarpTile > kMaxKWarpTile)
|
||||
return false;
|
||||
else
|
||||
return std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
}();
|
||||
#else
|
||||
static constexpr bool is_a_load_tr = false;
|
||||
static constexpr bool is_b_load_tr = false;
|
||||
#endif
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
template <index_t UnaryOpSize = 8,
|
||||
typename DstBlockTile,
|
||||
typename SrcTileWindow,
|
||||
typename DramTileWindowStep>
|
||||
CK_TILE_DEVICE void GlobalPrefetch(DstBlockTile& dst_block_tile,
|
||||
SrcTileWindow& dram_tile_window,
|
||||
const DramTileWindowStep& dram_tile_window_step) const
|
||||
{
|
||||
load_and_convert_tile<UnaryOpSize>(dst_block_tile, dram_tile_window);
|
||||
move_tile_window(dram_tile_window, dram_tile_window_step);
|
||||
}
|
||||
|
||||
template <typename DstBlockWindow, typename SrcTileWindow, typename DramTileWindowStep>
|
||||
CK_TILE_DEVICE void GlobalPrefetchAsync(DstBlockWindow& dst_block_window,
|
||||
SrcTileWindow& dram_tile_window,
|
||||
const DramTileWindowStep& dram_tile_window_step) const
|
||||
{
|
||||
async_load_tile(dst_block_window, dram_tile_window);
|
||||
move_tile_window(dram_tile_window, dram_tile_window_step);
|
||||
}
|
||||
|
||||
template <typename DstTileWindow, typename SrcBlockTile, typename ElementFunction>
|
||||
CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window,
|
||||
const SrcBlockTile& src_block_tile,
|
||||
const ElementFunction& element_func) const
|
||||
{
|
||||
const auto block_tile_tmp = tile_elementwise_in(element_func, src_block_tile);
|
||||
store_tile(lds_tile_window, block_tile_tmp);
|
||||
}
|
||||
|
||||
template <typename DstTileWindow, typename SrcBlockTile>
|
||||
CK_TILE_DEVICE void LocalPrefill(DstTileWindow& lds_tile_window,
|
||||
const SrcBlockTile& src_block_tile) const
|
||||
{
|
||||
store_tile(lds_tile_window, src_block_tile);
|
||||
}
|
||||
|
||||
template <typename DstBlockTile, typename SrcTileWindow, bool LoadTranspose = false>
|
||||
CK_TILE_DEVICE void LocalPrefetch(DstBlockTile& dst_block_tile,
|
||||
const SrcTileWindow& lds_tile_window,
|
||||
bool_constant<LoadTranspose> = {}) const
|
||||
{
|
||||
if constexpr(LoadTranspose)
|
||||
load_tile_transpose(dst_block_tile, lds_tile_window);
|
||||
else
|
||||
load_tile(dst_block_tile, lds_tile_window);
|
||||
}
|
||||
|
||||
template <typename OverrideADataType = ADataType, typename OverrideBDataType = BDataType>
|
||||
CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const
|
||||
{
|
||||
// A tile in LDS
|
||||
OverrideADataType* __restrict__ p_a_lds = static_cast<OverrideADataType*>(p_smem);
|
||||
constexpr auto a_lds_block_desc =
|
||||
Policy::template MakeALdsBlockDescriptor<Problem, OverrideADataType>();
|
||||
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
|
||||
// TODO: LDS alignment should come from Policy!
|
||||
constexpr index_t APackedSize = numeric_traits<OverrideADataType>::PackedSize;
|
||||
constexpr index_t a_lds_block_space_size =
|
||||
sizeof(OverrideADataType) * a_lds_block_desc.get_element_space_size() / APackedSize;
|
||||
constexpr index_t a_lds_block_space_size_aligned =
|
||||
integer_least_multiple(a_lds_block_space_size, 16);
|
||||
|
||||
// B tile in LDS
|
||||
OverrideBDataType* __restrict__ p_b_lds = static_cast<OverrideBDataType*>(
|
||||
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
|
||||
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
|
||||
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
|
||||
|
||||
return make_tuple(std::move(a_lds_block), std::move(b_lds_block));
|
||||
}
|
||||
|
||||
template <typename DramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, DramBlockWindowTmp>::value, bool>* =
|
||||
nullptr>
|
||||
CK_TILE_DEVICE constexpr auto CopyADramWindow(const DramBlockWindowTmp& dram_block_window_tmp,
|
||||
const array<index_t, 2>& offset = {0, 0}) const
|
||||
{
|
||||
constexpr bool is_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
|
||||
using YPerTile = std::conditional_t<is_col_major, number<KPerBlock>, number<MPerBlock>>;
|
||||
using XPerTile = std::conditional_t<is_col_major, number<MPerBlock>, number<KPerBlock>>;
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(
|
||||
dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
make_tuple(YPerTile{}, XPerTile{}),
|
||||
dram_block_window_tmp[number<idx>{}].get_window_origin() + offset,
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
},
|
||||
number<DramBlockWindowTmp::size()>{});
|
||||
return std::move(a_copy_dram_window);
|
||||
}
|
||||
|
||||
template <typename DramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, DramBlockWindowTmp>::value, bool>* =
|
||||
nullptr>
|
||||
CK_TILE_DEVICE constexpr auto CopyADramWindow(const DramBlockWindowTmp& dram_block_window_tmp,
|
||||
const array<index_t, 2>& offset = {0, 0}) const
|
||||
{
|
||||
constexpr bool is_col_major = std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
|
||||
using YPerTile = std::conditional_t<is_col_major, number<KPerBlock>, number<MPerBlock>>;
|
||||
using XPerTile = std::conditional_t<is_col_major, number<MPerBlock>, number<KPerBlock>>;
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(YPerTile{}, XPerTile{}),
|
||||
dram_block_window_tmp.get_window_origin() + offset,
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
|
||||
return std::move(a_copy_dram_window);
|
||||
}
|
||||
|
||||
template <typename DramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, DramBlockWindowTmp>::value, bool>* =
|
||||
nullptr>
|
||||
CK_TILE_DEVICE constexpr auto CopyBDramWindow(const DramBlockWindowTmp& dram_block_window_tmp,
|
||||
const array<index_t, 2>& offset = {0, 0}) const
|
||||
{
|
||||
constexpr bool is_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
using YPerTile = std::conditional_t<is_row_major, number<KPerBlock>, number<NPerBlock>>;
|
||||
using XPerTile = std::conditional_t<is_row_major, number<NPerBlock>, number<KPerBlock>>;
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(
|
||||
dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
make_tuple(YPerTile{}, XPerTile{}),
|
||||
dram_block_window_tmp[number<idx>{}].get_window_origin() + offset,
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
},
|
||||
number<DramBlockWindowTmp::size()>{});
|
||||
return std::move(a_copy_dram_window);
|
||||
}
|
||||
|
||||
template <typename DramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, DramBlockWindowTmp>::value, bool>* =
|
||||
nullptr>
|
||||
CK_TILE_DEVICE constexpr auto CopyBDramWindow(const DramBlockWindowTmp& dram_block_window_tmp,
|
||||
const array<index_t, 2>& offset = {0, 0}) const
|
||||
{
|
||||
constexpr bool is_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
using YPerTile = std::conditional_t<is_row_major, number<KPerBlock>, number<NPerBlock>>;
|
||||
using XPerTile = std::conditional_t<is_row_major, number<NPerBlock>, number<KPerBlock>>;
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window =
|
||||
make_tile_window(dram_block_window_tmp.get_bottom_tensor_view(),
|
||||
make_tuple(YPerTile{}, XPerTile{}),
|
||||
dram_block_window_tmp.get_window_origin() + offset,
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
|
||||
return std::move(a_copy_dram_window);
|
||||
}
|
||||
|
||||
template <typename ALdsTensorView, typename ALdsLoadTileDistr>
|
||||
CK_TILE_DEVICE constexpr auto MakeALdsWindows(const ALdsTensorView& a_lds_block_view,
|
||||
const ALdsLoadTileDistr&) const
|
||||
{
|
||||
auto a_lds_shape = []() {
|
||||
if constexpr(is_a_load_tr)
|
||||
return make_tuple(number<KPerBlock>{}, number<MPerBlock>{});
|
||||
else
|
||||
return make_tuple(number<MPerBlock>{}, number<KPerBlock>{});
|
||||
}();
|
||||
|
||||
auto a_copy_lds_window = make_tile_window(a_lds_block_view, a_lds_shape, {0, 0});
|
||||
|
||||
auto a_lds_load_tile_distr = []() {
|
||||
if constexpr(is_a_load_tr)
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
typename InputTileDistributionTraits<
|
||||
typename ALdsLoadTileDistr::DstrEncode,
|
||||
typename ALdsTensorView::DataType>::TransposedDstrEncode{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return ALdsLoadTileDistr{};
|
||||
}
|
||||
}();
|
||||
|
||||
auto a_lds_gemm_window =
|
||||
make_tile_window(a_lds_block_view, a_lds_shape, {0, 0}, a_lds_load_tile_distr);
|
||||
|
||||
return make_tuple(std::move(a_copy_lds_window), std::move(a_lds_gemm_window));
|
||||
}
|
||||
|
||||
template <
|
||||
typename ADramBlockWindowTmp,
|
||||
typename ALdsTensorView,
|
||||
typename ALdsLoadTileDistr,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ALdsTensorView>::value, bool>* = nullptr>
|
||||
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const ALdsTensorView& a_lds_block_view,
|
||||
const ALdsLoadTileDistr& a_lds_load_tile_distr,
|
||||
const array<index_t, 2>& offset = {0, 0}) const
|
||||
{
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window = CopyADramWindow(a_dram_block_window_tmp, offset);
|
||||
|
||||
// Create LDS windows
|
||||
auto [a_copy_lds_window, a_lds_gemm_window] =
|
||||
MakeALdsWindows(a_lds_block_view, a_lds_load_tile_distr);
|
||||
|
||||
return make_tuple(std::move(a_copy_dram_window),
|
||||
std::move(a_copy_lds_window),
|
||||
std::move(a_lds_gemm_window));
|
||||
}
|
||||
|
||||
// Unified GetAWindows that supports 1, 2, or 3 LDS buffers
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename ALdsTensorViewsTuple,
|
||||
typename ALdsLoadTileDistr,
|
||||
typename std::enable_if_t<is_detected<is_tuple, ALdsTensorViewsTuple>::value, bool>* =
|
||||
nullptr>
|
||||
CK_TILE_DEVICE constexpr auto GetAWindows(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const ALdsTensorViewsTuple& a_lds_block_views_tuple,
|
||||
const ALdsLoadTileDistr& a_lds_load_tile_distr,
|
||||
const array<index_t, 2>& offset = {0, 0}) const
|
||||
{
|
||||
// A DRAM tile window for load
|
||||
auto a_copy_dram_window = CopyADramWindow(a_dram_block_window_tmp, offset);
|
||||
|
||||
// Create LDS windows for each buffer
|
||||
constexpr index_t num_buffers = ALdsTensorViewsTuple::size();
|
||||
auto a_lds_windows = generate_tuple(
|
||||
[&](auto i) {
|
||||
return MakeALdsWindows(a_lds_block_views_tuple[i], a_lds_load_tile_distr);
|
||||
},
|
||||
number<num_buffers>{});
|
||||
|
||||
// Return: (dram_window, lds_windows_tuple)
|
||||
// lds_windows_tuple[i] = (copy_lds_window_i, lds_gemm_window_i)
|
||||
return make_tuple(std::move(a_copy_dram_window), std::move(a_lds_windows));
|
||||
}
|
||||
|
||||
template <typename BLdsTensorView, typename BLdsLoadTileDistr>
|
||||
CK_TILE_DEVICE constexpr auto MakeBLdsWindows(const BLdsTensorView& b_lds_block_view,
|
||||
const BLdsLoadTileDistr&) const
|
||||
{
|
||||
auto b_lds_shape = []() {
|
||||
if constexpr(is_b_load_tr)
|
||||
return make_tuple(number<KPerBlock>{}, number<NPerBlock>{});
|
||||
else
|
||||
return make_tuple(number<NPerBlock>{}, number<KPerBlock>{});
|
||||
}();
|
||||
|
||||
auto b_copy_lds_window = make_tile_window(b_lds_block_view, b_lds_shape, {0, 0});
|
||||
|
||||
auto b_lds_load_tile_distr = []() {
|
||||
if constexpr(is_b_load_tr)
|
||||
{
|
||||
return make_static_tile_distribution(
|
||||
typename InputTileDistributionTraits<
|
||||
typename BLdsLoadTileDistr::DstrEncode,
|
||||
typename BLdsTensorView::DataType>::TransposedDstrEncode{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return BLdsLoadTileDistr{};
|
||||
}
|
||||
}();
|
||||
|
||||
auto b_lds_gemm_window =
|
||||
make_tile_window(b_lds_block_view, b_lds_shape, {0, 0}, b_lds_load_tile_distr);
|
||||
|
||||
return make_tuple(std::move(b_copy_lds_window), std::move(b_lds_gemm_window));
|
||||
}
|
||||
|
||||
template <
|
||||
typename BDramBlockWindowTmp,
|
||||
typename BLdsTensorView,
|
||||
typename BLdsLoadTileDistr,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, BLdsTensorView>::value, bool>* = nullptr>
|
||||
CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BLdsTensorView& b_lds_block_view,
|
||||
const BLdsLoadTileDistr& b_lds_load_tile_distr,
|
||||
const array<index_t, 2>& offset = {0, 0}) const
|
||||
{
|
||||
// A DRAM tile window for load
|
||||
auto b_copy_dram_window = CopyBDramWindow(b_dram_block_window_tmp, offset);
|
||||
|
||||
// Create LDS windows
|
||||
auto [b_copy_lds_window, b_lds_gemm_window] =
|
||||
MakeBLdsWindows(b_lds_block_view, b_lds_load_tile_distr);
|
||||
|
||||
return make_tuple(std::move(b_copy_dram_window),
|
||||
std::move(b_copy_lds_window),
|
||||
std::move(b_lds_gemm_window));
|
||||
}
|
||||
|
||||
// Unified GetBWindows that supports 1, 2, or 3 LDS buffers
|
||||
template <typename BDramBlockWindowTmp,
|
||||
typename BLdsTensorViewsTuple,
|
||||
typename BLdsLoadTileDistr,
|
||||
typename std::enable_if_t<is_detected<is_tuple, BLdsTensorViewsTuple>::value, bool>* =
|
||||
nullptr>
|
||||
CK_TILE_DEVICE constexpr auto GetBWindows(const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BLdsTensorViewsTuple& b_lds_block_views_tuple,
|
||||
const BLdsLoadTileDistr& b_lds_load_tile_distr,
|
||||
const array<index_t, 2>& offset = {0, 0}) const
|
||||
{
|
||||
// B DRAM tile window for load
|
||||
auto b_copy_dram_window = CopyBDramWindow(b_dram_block_window_tmp, offset);
|
||||
|
||||
// Create LDS windows for each buffer
|
||||
constexpr index_t num_buffers = BLdsTensorViewsTuple::size();
|
||||
auto b_lds_windows = generate_tuple(
|
||||
[&](auto i) {
|
||||
return MakeBLdsWindows(b_lds_block_views_tuple[i], b_lds_load_tile_distr);
|
||||
},
|
||||
number<num_buffers>{});
|
||||
|
||||
// Return: (dram_window, lds_windows_tuple)
|
||||
// lds_windows_tuple[i] = (copy_lds_window_i, lds_gemm_window_i)
|
||||
return make_tuple(std::move(b_copy_dram_window), std::move(b_lds_windows));
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,653 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_async_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
template <typename Problem>
|
||||
struct BaseGemmPipelineAgBgCrCompAsync
|
||||
{
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
if(num_loop == 1)
|
||||
{
|
||||
return TailNumber::One;
|
||||
}
|
||||
if(num_loop % PrefetchStages == 1)
|
||||
{
|
||||
return TailNumber::Three;
|
||||
}
|
||||
else
|
||||
{
|
||||
return TailNumber::Two;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
|
||||
{
|
||||
// Use amd_wave_read_first_lane to avoid higher resource usage.
|
||||
// It forces to store these values in SGPR.
|
||||
// Compiler cannot deduce if one path is used for all threads
|
||||
const bool has_hot_loop_first_lane = amd_wave_read_first_lane(has_hot_loop);
|
||||
const TailNumber tail_number_first_lane = amd_wave_read_first_lane(tail_number);
|
||||
// Handle all the valid cases.
|
||||
if(has_hot_loop_first_lane)
|
||||
{
|
||||
if(tail_number_first_lane == TailNumber::Three)
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Three>{});
|
||||
}
|
||||
else if(tail_number_first_lane == TailNumber::Two)
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Two>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(tail_number_first_lane == TailNumber::Three)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Three>{});
|
||||
}
|
||||
else if(tail_number_first_lane == TailNumber::Two)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Two>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return (run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::One>{}));
|
||||
}
|
||||
}
|
||||
// If execution reaches here, it's an invalid tail_number because it wasn't handled above.
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
__builtin_unreachable();
|
||||
#else
|
||||
throw std::logic_error(
|
||||
"Invalid TailNumber: Only TailNumber::Three and TailNumber::Two are supported");
|
||||
#endif
|
||||
}
|
||||
|
||||
CK_TILE_HOST static constexpr auto GetName() { return "COMPUTE_ASYNC"; }
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Compute optimized pipeline version async; which is based on V4.
|
||||
*
|
||||
* This pipeline introduces asynchronous load from global memory to LDS,
|
||||
* skipping the intermediate loading into pipeline registers.
|
||||
*/
|
||||
template <typename Problem, typename Policy = GemmPipelineAgBgCrCompAsyncDefaultPolicy>
|
||||
struct GemmPipelineAgBgCrCompAsync : public BaseGemmPipelineAgBgCrCompAsync<Problem>
|
||||
{
|
||||
using Base = BaseGemmPipelineAgBgCrCompAsync<Problem>;
|
||||
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
|
||||
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
|
||||
static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
using I2 = number<2>;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr bool Async = true;
|
||||
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
|
||||
}
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeB()
|
||||
{
|
||||
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
|
||||
}
|
||||
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
|
||||
|
||||
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
|
||||
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
|
||||
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
static constexpr index_t Preshuffle = Problem::Preshuffle;
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
|
||||
static_assert(DoubleSmemBuffer == true, "pipeline requires double smem buffer");
|
||||
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
|
||||
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
|
||||
{
|
||||
// clang-format off
|
||||
return "COMPUTE_ASYNC";
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
constexpr index_t smem_size = Policy::template GetSmemSize<Problem>();
|
||||
return 2 * smem_size;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
|
||||
{
|
||||
return Policy::template IsTransposeC<Problem>();
|
||||
}
|
||||
|
||||
template <GemmPipelineScheduler Scheduler>
|
||||
struct PipelineImpl : public PipelineImplBase
|
||||
{
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
|
||||
{
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{});
|
||||
constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{});
|
||||
constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{});
|
||||
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
|
||||
constexpr index_t A_Buffer_Load_Inst_Num =
|
||||
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
|
||||
constexpr index_t B_Buffer_Load_Inst_Num =
|
||||
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
|
||||
|
||||
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
|
||||
(BlockSize / WaveSize) /
|
||||
(MPerXDL * NPerXDL * KPerXDL);
|
||||
|
||||
constexpr auto num_buffer_load_inst = A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num;
|
||||
constexpr auto num_issue = num_buffer_load_inst;
|
||||
|
||||
static_for<0, num_buffer_load_inst, 1>{}([&](auto i) {
|
||||
// TODO: this will likely need to be redesigned after (1) changes to reading from
|
||||
// LDS and (2) re-profiling
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA : 1
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
LLVMSchedGroupMask::DS_READ, 1, 0); // DS read : 1
|
||||
__builtin_amdgcn_sched_group_barrier(LLVMSchedGroupMask::MFMA, 1, 0); // MFMA: 1
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
LLVMSchedGroupMask::VMEM_READ, 1, 0); // VMEM read :1
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
LLVMSchedGroupMask::MFMA, C_MFMA_Inst_Num / num_issue - 2, 0); // MFMA : 6
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* __restrict__ p_smem) const
|
||||
{
|
||||
// TODO support multi-ABD
|
||||
static_assert(1 == std::tuple_size_v<AsDramBlockWindowTmp>);
|
||||
static_assert(1 == std::tuple_size_v<BsDramBlockWindowTmp>);
|
||||
using ADramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
|
||||
using BDramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
|
||||
// TODO currently fused elementwise are not supported
|
||||
ignore = a_element_func;
|
||||
ignore = b_element_func;
|
||||
static_assert(std::is_same_v<remove_cvref_t<decltype(a_element_func)>,
|
||||
element_wise::PassThrough>);
|
||||
static_assert(std::is_same_v<remove_cvref_t<decltype(b_element_func)>,
|
||||
element_wise::PassThrough>);
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"Data Type conflict on A and B matrix input data type.");
|
||||
|
||||
constexpr bool is_a_col_major =
|
||||
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
static_assert(is_a_col_major
|
||||
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"A block window has incorrect lengths for defined ALayout!");
|
||||
static_assert(is_b_row_major
|
||||
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
|
||||
////////////// global window & register /////////////////
|
||||
// A DRAM tile window(s) for load
|
||||
auto a_tile_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(
|
||||
a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
|
||||
a_dram_block_window_tmp[number<idx>{}].get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
},
|
||||
number<AsLayout::size()>{});
|
||||
// B DRAM window(s) for load
|
||||
auto b_tile_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(
|
||||
b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
b_dram_block_window_tmp[number<idx>{}].get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
},
|
||||
number<BsLayout::size()>{});
|
||||
|
||||
// this pipeline has a pair of LDS buffers per logical tile
|
||||
constexpr index_t smem_size = Policy::template GetSmemSize<Problem>();
|
||||
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem);
|
||||
auto&& [a_lds_block1, b_lds_block1] =
|
||||
Base::GetABLdsTensorViews(static_cast<char*>(p_smem) + smem_size);
|
||||
|
||||
// set up LDS tile shapes
|
||||
constexpr auto a_lds_shape = []() {
|
||||
if constexpr(is_a_load_tr_v)
|
||||
return make_tuple(number<KPerBlock>{}, number<MPerBlock>{});
|
||||
else
|
||||
return make_tuple(number<MPerBlock>{}, number<KPerBlock>{});
|
||||
}();
|
||||
|
||||
constexpr auto b_lds_shape = []() {
|
||||
if constexpr(is_b_load_tr_v)
|
||||
return make_tuple(number<KPerBlock>{}, number<NPerBlock>{});
|
||||
else
|
||||
return make_tuple(number<NPerBlock>{}, number<KPerBlock>{});
|
||||
}();
|
||||
|
||||
// LDS tile windows for storing, one per LDS buffer
|
||||
auto a_copy_lds_window0 = make_tile_window(a_lds_block0, a_lds_shape, {0, 0});
|
||||
|
||||
auto a_copy_lds_window1 = make_tile_window(a_lds_block1, a_lds_shape, {0, 0});
|
||||
|
||||
auto b_copy_lds_window0 = make_tile_window(b_lds_block0, b_lds_shape, {0, 0});
|
||||
|
||||
auto b_copy_lds_window1 = make_tile_window(b_lds_block1, b_lds_shape, {0, 0});
|
||||
|
||||
// initialize DRAM window steps, used to advance the DRAM windows
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step =
|
||||
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
|
||||
// read A(0), B(0) from DRAM to LDS window(0)
|
||||
// and advance the DRAM windows
|
||||
Base::GlobalPrefetchAsync(
|
||||
a_copy_lds_window0, a_tile_windows[number<0>{}], a_dram_tile_window_step);
|
||||
Base::GlobalPrefetchAsync(
|
||||
b_copy_lds_window0, b_tile_windows[number<0>{}], b_dram_tile_window_step);
|
||||
|
||||
// initialize block gemm
|
||||
auto block_gemm = BlockGemm();
|
||||
|
||||
// initialize C block tile
|
||||
auto c_block_tile = block_gemm.MakeCBlockTile();
|
||||
clear_tile(c_block_tile);
|
||||
|
||||
// read A(1), B(1) from DRAM to LDS window(1)
|
||||
// and advance the DRAM windows
|
||||
Base::GlobalPrefetchAsync(
|
||||
a_copy_lds_window1, a_tile_windows[number<0>{}], a_dram_tile_window_step);
|
||||
Base::GlobalPrefetchAsync(
|
||||
b_copy_lds_window1, b_tile_windows[number<0>{}], b_dram_tile_window_step);
|
||||
|
||||
// tile distribution for the register tiles
|
||||
constexpr auto ALdsTileDistr =
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
|
||||
constexpr auto BLdsTileDistr =
|
||||
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
|
||||
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
|
||||
|
||||
// register tiles; double buffering -> a register tile corresponds to a LDS tile window
|
||||
ALdsTile a_block_tile0, a_block_tile1;
|
||||
BLdsTile b_block_tile0, b_block_tile1;
|
||||
|
||||
constexpr auto a_lds_input_tile_distr = [ALdsTileDistr]() {
|
||||
if constexpr(is_a_load_tr_v)
|
||||
return make_static_tile_distribution(
|
||||
typename InputTileDistributionTraits<
|
||||
typename decltype(ALdsTileDistr)::DstrEncode,
|
||||
typename Problem::ADataType>::TransposedDstrEncode{});
|
||||
else
|
||||
return ALdsTileDistr;
|
||||
}();
|
||||
constexpr auto b_lds_input_tile_distr = [BLdsTileDistr]() {
|
||||
if constexpr(is_b_load_tr_v)
|
||||
return make_static_tile_distribution(
|
||||
typename InputTileDistributionTraits<
|
||||
typename decltype(BLdsTileDistr)::DstrEncode,
|
||||
typename Problem::BDataType>::TransposedDstrEncode{});
|
||||
else
|
||||
return BLdsTileDistr;
|
||||
}();
|
||||
|
||||
// LDS tile windows for reading;
|
||||
// they share the data pointer with the LDS windows for storing
|
||||
// but also associate with a distribution to produce a register tile when reading
|
||||
auto a_lds_ld_window0 =
|
||||
make_tile_window(a_lds_block0, a_lds_shape, {0, 0}, a_lds_input_tile_distr);
|
||||
auto a_lds_ld_window1 =
|
||||
make_tile_window(a_lds_block1, a_lds_shape, {0, 0}, a_lds_input_tile_distr);
|
||||
auto b_lds_ld_window0 =
|
||||
make_tile_window(b_lds_block0, b_lds_shape, {0, 0}, b_lds_input_tile_distr);
|
||||
auto b_lds_ld_window1 =
|
||||
make_tile_window(b_lds_block1, b_lds_shape, {0, 0}, b_lds_input_tile_distr);
|
||||
|
||||
static_assert(!(is_tile_window_linear_v<decltype(a_lds_ld_window0)>) &&
|
||||
!(is_tile_window_linear_v<decltype(a_lds_ld_window1)>) &&
|
||||
!(is_tile_window_linear_v<decltype(b_lds_ld_window0)>) &&
|
||||
!(is_tile_window_linear_v<decltype(b_lds_ld_window1)>),
|
||||
"LDS windows must not be linear");
|
||||
|
||||
// write to LDS window(0) must complete before the local prefetch
|
||||
block_sync_lds_direct_load();
|
||||
// read A(0), B(0) from LDS window(0) to pipeline registers(0)
|
||||
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
|
||||
// LDS window(0) contents are overwritten below by global prefetch, need to sync
|
||||
block_sync_lds();
|
||||
// read A(2), B(2) from DRAM to LDS window(0)
|
||||
// and advance the DRAM windows
|
||||
Base::GlobalPrefetchAsync(
|
||||
a_copy_lds_window0, a_tile_windows[number<0>{}], a_dram_tile_window_step);
|
||||
Base::GlobalPrefetchAsync(
|
||||
b_copy_lds_window0, b_tile_windows[number<0>{}], b_dram_tile_window_step);
|
||||
|
||||
if constexpr(HasHotLoop)
|
||||
{
|
||||
// we have had 3 global prefetches so far, indexed (0, 1, 2).
|
||||
index_t i_global_read = amd_wave_read_first_lane(3);
|
||||
// alternate ping: (read to register tile(1), use register tile(0) as gemm input)
|
||||
// pong: (read to register tile(0), use register tile(1) as gemm input)
|
||||
do
|
||||
{
|
||||
// ping
|
||||
{
|
||||
// read A(i-1), B(i-1) from LDS window(1) to pipeline registers(1)
|
||||
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
|
||||
// LDS window(1) contents are overwritten by global prefetch, need to sync
|
||||
block_sync_lds();
|
||||
// read A(i), B(i) from DRAM to LDS window(1)
|
||||
// and advance the DRAM windows
|
||||
Base::GlobalPrefetchAsync(a_copy_lds_window1,
|
||||
a_tile_windows[number<0>{}],
|
||||
a_dram_tile_window_step);
|
||||
Base::GlobalPrefetchAsync(b_copy_lds_window1,
|
||||
b_tile_windows[number<0>{}],
|
||||
b_dram_tile_window_step);
|
||||
// C(i-3) = A(i-3) @ B(i-3)
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
HotLoopScheduler();
|
||||
}
|
||||
// pong
|
||||
{
|
||||
// write to LDS window(0) must complete before the local prefetch
|
||||
block_sync_lds_direct_load();
|
||||
// read A(i), B(i) from LDS window(0) to pipeline registers(0)
|
||||
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
|
||||
// LDS window(0) contents are overwritten by global prefetch, need to sync
|
||||
block_sync_lds();
|
||||
// read A(i+1), B(i+1) from DRAM to LDS window(0)
|
||||
// and advance the DRAM windows
|
||||
Base::GlobalPrefetchAsync(a_copy_lds_window0,
|
||||
a_tile_windows[number<0>{}],
|
||||
a_dram_tile_window_step);
|
||||
Base::GlobalPrefetchAsync(b_copy_lds_window0,
|
||||
b_tile_windows[number<0>{}],
|
||||
b_dram_tile_window_step);
|
||||
// C(i-2) = A(i-2) @ B(i-2)
|
||||
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
|
||||
HotLoopScheduler();
|
||||
}
|
||||
i_global_read += 2;
|
||||
} while(i_global_read < num_loop);
|
||||
}
|
||||
|
||||
// 3 block gemms remaining
|
||||
if constexpr(TailNum == TailNumber::Three)
|
||||
{
|
||||
{
|
||||
// read A(num_loop-1), B(num_loop-1) from LDS window(1) to pipeline registers(1)
|
||||
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
|
||||
// C(num_loop-2) = A(num_loop-2) @ B(num_loop-2)
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
}
|
||||
{
|
||||
// write to LDS window(0) must complete before the local prefetch
|
||||
block_sync_lds_direct_load();
|
||||
// read A(num_loop), B(num_loop) from LDS window(0) to pipeline registers(0)
|
||||
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
|
||||
// C(num_loop-1) = A(num_loop-1) @ B(num_loop-1)
|
||||
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
|
||||
}
|
||||
{
|
||||
// C(num_loop) = A(num_loop) @ B(num_loop)
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
}
|
||||
}
|
||||
else if(TailNum == TailNumber::Two)
|
||||
// 2 block gemms remaining
|
||||
{
|
||||
{
|
||||
// read A(num_loop), B(num_loop) from LDS window(1) to pipeline registers(1)
|
||||
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
|
||||
// C(num_loop-1) = A(num_loop-1) @ B(num_loop-1)
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
}
|
||||
{
|
||||
// C(num_loop) = A(num_loop) @ B(num_loop)
|
||||
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
|
||||
}
|
||||
}
|
||||
else if(TailNum == TailNumber::One)
|
||||
{
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
return c_block_tile;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
public:
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const index_t num_loop,
|
||||
void* __restrict__ p_smem) const
|
||||
{
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp,
|
||||
element_wise::PassThrough{},
|
||||
b_dram_block_window_tmp,
|
||||
element_wise::PassThrough{},
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
a_element_func,
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
public:
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const index_t num_loop,
|
||||
void* __restrict__ p_smem) const
|
||||
{
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
element_wise::PassThrough{},
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
element_wise::PassThrough{},
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,132 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
// Default policy for GemmPipelineAgBgCrCompAsync
|
||||
// Customized methods: MakeALdsBlockDescriptor, MakeBLdsBlockDescriptor
|
||||
// GetBlockGemm implementation is copied from GemmPipelineAgBgCrCompV4DefaultPolicy
|
||||
struct GemmPipelineAgBgCrCompAsyncDefaultPolicy
|
||||
: public UniversalGemmBasePolicy<GemmPipelineAgBgCrCompAsyncDefaultPolicy>
|
||||
{
|
||||
static constexpr auto ATileAccessPattern = tile_distribution_pattern::warp_raked;
|
||||
static constexpr auto BTileAccessPattern = tile_distribution_pattern::warp_raked;
|
||||
|
||||
template <typename Problem,
|
||||
typename OverrideADataType = remove_cvref_t<typename Problem::ADataType>>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
if constexpr(is_a_load_tr<Problem>)
|
||||
{
|
||||
// TODO: better LDS descriptor for performance
|
||||
// This branch is reusing the logic from
|
||||
// UniversalGemmBasePolicy::MakeALdsBlockDescriptor
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( //
|
||||
make_tuple(number<KPerBlock>{}, number<MPerBlock>{}),
|
||||
make_tuple(number<MPerBlock>{}, number<1>{}),
|
||||
number<MPerBlock>{},
|
||||
number<1>{});
|
||||
return a_lds_block_desc_0;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<MPerBlock>{}, number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<MPerBlock>{}),
|
||||
make_merge_transform(make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
if constexpr(is_b_load_tr<Problem>)
|
||||
{
|
||||
// TODO: better LDS descriptor for performance
|
||||
// This branch is reusing the logic from
|
||||
// UniversalGemmBasePolicy::MakeBLdsBlockDescriptor
|
||||
constexpr auto b_lds_block_desc_0 =
|
||||
make_naive_tensor_descriptor(make_tuple(number<KPerBlock>{}, number<NPerBlock>{}),
|
||||
make_tuple(number<NPerBlock>{}, number<1>{}),
|
||||
number<NPerBlock>{},
|
||||
number<1>{});
|
||||
return b_lds_block_desc_0;
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t KPack = GetSmemPackB<Problem>();
|
||||
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<NPerBlock>{}, number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
return transform_tensor_descriptor(
|
||||
b_lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_pass_through_transform(number<NPerBlock>{}),
|
||||
make_merge_transform(make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
|
||||
constexpr index_t vector_size =
|
||||
DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType);
|
||||
constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size();
|
||||
constexpr auto wg_attr_num_access =
|
||||
!(is_a_load_tr<Problem> || is_b_load_tr<Problem>) ? WGAttrNumAccessEnum::Single
|
||||
: vector_size == thread_elements ? WGAttrNumAccessEnum::Single
|
||||
: vector_size * 2 == thread_elements ? WGAttrNumAccessEnum::Double
|
||||
: vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad
|
||||
: WGAttrNumAccessEnum::Invalid;
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType, // AccDataType
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC,
|
||||
false,
|
||||
false,
|
||||
wg_attr_num_access>;
|
||||
|
||||
using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
|
||||
return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,794 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
template <typename Problem>
|
||||
struct BaseGemmPipelineAgBgCrCompV3
|
||||
{
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
if constexpr(Problem::BlockGemmShape::NumWarps == 8)
|
||||
return num_loop > 3;
|
||||
else
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
if(BlockHasHotloop(num_loop) || num_loop == 3)
|
||||
if constexpr(Problem::BlockGemmShape::NumWarps == 8)
|
||||
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
|
||||
else
|
||||
return TailNumber::Odd;
|
||||
else if(num_loop == 2)
|
||||
return TailNumber::Even;
|
||||
else
|
||||
return (Problem::BlockGemmShape::NumWarps == 8) ? TailNumber::One : TailNumber::Odd;
|
||||
}
|
||||
|
||||
template <size_t I = 0, typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
|
||||
{
|
||||
// Use amd_wave_read_first_lane to avoid higher resource usage.
|
||||
// It forces to store these values in SGPR.
|
||||
// Compiler cannot deduce if one path is used for all threads
|
||||
const bool has_hot_loop_first_lane = amd_wave_read_first_lane(has_hot_loop);
|
||||
const TailNumber tail_number_first_lane = amd_wave_read_first_lane(tail_number);
|
||||
|
||||
constexpr auto scenarios = []() {
|
||||
if constexpr(Problem::BlockGemmShape::NumWarps == 8)
|
||||
return std::array<std::pair<bool, ck_tile::TailNumber>, 5>{
|
||||
std::make_pair(false, TailNumber::One), // 1 loop
|
||||
std::make_pair(false, TailNumber::Even), // 2 loop
|
||||
std::make_pair(false, TailNumber::Odd), // 3
|
||||
std::make_pair(true, TailNumber::Even), // 4 / 6 / 8 / ... loops
|
||||
std::make_pair(true, TailNumber::Odd), // 5 / 7 / 9 / ... loops
|
||||
};
|
||||
else
|
||||
return std::array<std::pair<bool, ck_tile::TailNumber>, 3>{
|
||||
std::make_pair(true, TailNumber::Odd),
|
||||
std::make_pair(false, TailNumber::Odd),
|
||||
std::make_pair(false, TailNumber::Even),
|
||||
};
|
||||
}();
|
||||
if(has_hot_loop_first_lane == scenarios[I].first &&
|
||||
tail_number_first_lane == scenarios[I].second)
|
||||
return run_func(bool_constant<scenarios[I].first>{}, constant<scenarios[I].second>{});
|
||||
else if constexpr(I + 1 < scenarios.size())
|
||||
return TailHandler<I + 1>(run_func, has_hot_loop, tail_number);
|
||||
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
// This path should be unreachable in device code if tail_number is valid.
|
||||
__builtin_unreachable();
|
||||
#else
|
||||
// If execution reaches here, it's an invalid combination of arguments.
|
||||
throw std::logic_error("Invalid TailNumber value: must be "
|
||||
"TailNumber::Odd or TailNumber::Even");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// Compute optimized pipeline
|
||||
// GlobalPrefetchStages: 2
|
||||
// LocalPreFillStages: 1
|
||||
// LocalPreFetchStages: 1
|
||||
// LocalSharedMemoryBuffer: 1
|
||||
template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
|
||||
struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
|
||||
{
|
||||
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
|
||||
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
|
||||
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
using I2 = number<2>;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr bool Async = false;
|
||||
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
|
||||
}
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeB()
|
||||
{
|
||||
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
|
||||
}
|
||||
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
|
||||
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
static constexpr index_t Preshuffle = Problem::Preshuffle;
|
||||
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
|
||||
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
|
||||
|
||||
using Base::PrefetchStages;
|
||||
using Base::UsePersistentKernel;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
|
||||
{
|
||||
// clang-format off
|
||||
return "COMPUTE_V3";
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
|
||||
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
|
||||
return concat('_', "pipeline_AgBgCrCompV3",
|
||||
concat('x', MPerBlock, NPerBlock, KPerBlock), BlockSize,
|
||||
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
|
||||
concat('x', WaveNumM, WaveNumN),
|
||||
concat('x', kPadM, kPadN, kPadK),
|
||||
Problem::GetName());
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST static std::string Print()
|
||||
{
|
||||
constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
|
||||
constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
|
||||
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
|
||||
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
|
||||
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
|
||||
|
||||
// Below should be equal to AK1|BK1
|
||||
constexpr index_t A_LDS_Read_Width = GetSmemPackA();
|
||||
constexpr index_t B_LDS_Read_Width = GetSmemPackB();
|
||||
|
||||
constexpr index_t A_LDS_Write_Width = GetSmemPackA();
|
||||
constexpr index_t B_LDS_Write_Width = GetSmemPackB();
|
||||
|
||||
constexpr index_t A_Buffer_Load_Inst_Num =
|
||||
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
|
||||
constexpr index_t B_Buffer_Load_Inst_Num =
|
||||
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
|
||||
|
||||
constexpr index_t A_LDS_Write_Inst_Num =
|
||||
MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
|
||||
constexpr index_t B_LDS_Write_Inst_Num =
|
||||
NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);
|
||||
|
||||
constexpr index_t A_LDS_Read_Inst_Num =
|
||||
WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
|
||||
constexpr index_t B_LDS_Read_Inst_Num =
|
||||
WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
|
||||
|
||||
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
|
||||
(BlockSize / WaveSize) / (MPerXDL * NPerXDL * KPerXDL);
|
||||
|
||||
auto str = std::stringstream{};
|
||||
|
||||
str << "A/B vector size: " << GetVectorSizeA() << ", " << GetVectorSizeB() << "\n"
|
||||
<< "A/B LDS read/write width: " << A_LDS_Read_Width << ", " << B_LDS_Read_Width << "\n"
|
||||
<< "A/B buffer load inst: " << A_Buffer_Load_Inst_Num << ", " << B_Buffer_Load_Inst_Num
|
||||
<< "\n"
|
||||
<< "A/B LDS write inst: " << A_LDS_Write_Inst_Num << ", " << B_LDS_Write_Inst_Num
|
||||
<< "\n"
|
||||
<< "A/B LDS read inst: " << A_LDS_Read_Inst_Num << ", " << B_LDS_Read_Inst_Num << "\n"
|
||||
<< "C MFMA inst: " << C_MFMA_Inst_Num << "\n"
|
||||
<< "KPack: " << BlockGemm::Traits::KPack << "\n"
|
||||
<< "PrefetchStages: " << PrefetchStages << "\n";
|
||||
return str.str();
|
||||
}
|
||||
|
||||
template <GemmPipelineScheduler Scheduler>
|
||||
struct PipelineImpl : public PipelineImplBase
|
||||
{
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
|
||||
{
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
constexpr index_t MPerXDL = BlockGemm::WarpGemm::kM;
|
||||
constexpr index_t NPerXDL = BlockGemm::WarpGemm::kN;
|
||||
constexpr index_t KPerXDL = BlockGemm::WarpGemm::WarpGemmAttribute::Impl::kK;
|
||||
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
|
||||
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
|
||||
|
||||
// Below should be equal to AK1|BK1
|
||||
constexpr index_t A_LDS_Read_Width = GetSmemPackA();
|
||||
constexpr index_t B_LDS_Read_Width = GetSmemPackB();
|
||||
|
||||
constexpr index_t A_LDS_Write_Width = GetSmemPackA();
|
||||
constexpr index_t B_LDS_Write_Width = GetSmemPackB();
|
||||
|
||||
constexpr index_t A_Buffer_Load_Inst_Num =
|
||||
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
|
||||
constexpr index_t B_Buffer_Load_Inst_Num =
|
||||
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
|
||||
|
||||
constexpr index_t A_LDS_Write_Inst_Num =
|
||||
MPerBlock * KPerBlock / (BlockSize * A_LDS_Write_Width);
|
||||
constexpr index_t B_LDS_Write_Inst_Num =
|
||||
NPerBlock * KPerBlock / (BlockSize * B_LDS_Write_Width);
|
||||
|
||||
constexpr index_t A_LDS_Read_Inst_Num =
|
||||
WaveNumN * MPerBlock * KPerBlock / (BlockSize * A_LDS_Read_Width);
|
||||
constexpr index_t B_LDS_Read_Inst_Num =
|
||||
WaveNumM * NPerBlock * KPerBlock / (BlockSize * B_LDS_Read_Width);
|
||||
|
||||
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
|
||||
(BlockSize / WaveSize) /
|
||||
(MPerXDL * NPerXDL * KPerXDL);
|
||||
|
||||
// A/B split schedule
|
||||
// compiler is likely to use ds_read2 when instruction width smaller than 16bytes
|
||||
constexpr auto num_ds_read_inst_a =
|
||||
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num
|
||||
: A_LDS_Read_Inst_Num / 2;
|
||||
constexpr auto num_ds_read_inst_b =
|
||||
B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num
|
||||
: B_LDS_Read_Inst_Num / 2;
|
||||
|
||||
constexpr auto num_ds_write_inst_a = A_LDS_Write_Inst_Num;
|
||||
constexpr auto num_ds_write_inst_b = B_LDS_Write_Inst_Num;
|
||||
|
||||
constexpr auto num_buffer_load_inst_a = A_Buffer_Load_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst_b = B_Buffer_Load_Inst_Num;
|
||||
|
||||
constexpr auto num_mfma_inst = C_MFMA_Inst_Num;
|
||||
|
||||
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
|
||||
constexpr auto ds_read_a_issue_cycle =
|
||||
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_b_issue_cycle =
|
||||
B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_a_mfma_rate =
|
||||
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
|
||||
constexpr auto ds_read_b_mfma_rate =
|
||||
(mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
|
||||
|
||||
constexpr auto num_dsread_a_mfma =
|
||||
(num_ds_read_inst_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
|
||||
constexpr auto num_dsread_b_mfma =
|
||||
(num_ds_read_inst_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
|
||||
|
||||
// stage 1
|
||||
// Separate this part?
|
||||
// constexpr auto num_mfma_per_ds_read = sizeof(ComputeDataType) / sizeof(ADataType) >
|
||||
// sizeof(ComputeDataType) /
|
||||
// sizeof(BDataType)
|
||||
// ? sizeof(ComputeDataType) /
|
||||
// sizeof(ADataType) : sizeof(ComputeDataType)
|
||||
// / sizeof(BDataType);
|
||||
constexpr auto num_mfma_stage1 =
|
||||
num_mfma_inst - (num_dsread_a_mfma + num_dsread_b_mfma);
|
||||
constexpr auto num_mfma_per_issue =
|
||||
num_mfma_stage1 / (num_buffer_load_inst_a + num_buffer_load_inst_b);
|
||||
constexpr auto num_dswrite_per_issue_a = num_ds_write_inst_a / num_buffer_load_inst_a;
|
||||
constexpr auto num_dswrite_per_issue_b = num_ds_write_inst_b / num_buffer_load_inst_b;
|
||||
|
||||
static_for<0, num_buffer_load_inst_a, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
|
||||
ignore = idswrite;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
|
||||
});
|
||||
static_for<0, num_buffer_load_inst_b, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
|
||||
ignore = idswrite;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
|
||||
});
|
||||
|
||||
// stage 2
|
||||
static_for<0, num_dsread_a_mfma, 1>{}([&](auto i) {
|
||||
if constexpr((num_ds_read_inst_a - (i + 1) * ds_read_a_mfma_rate) >=
|
||||
ds_read_a_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100,
|
||||
num_ds_read_inst_a - (num_dsread_a_mfma - 1) * ds_read_a_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
|
||||
static_for<0, num_dsread_b_mfma, 1>{}([&](auto i) {
|
||||
if constexpr((num_ds_read_inst_b - (i + 1) * ds_read_b_mfma_rate) >=
|
||||
ds_read_b_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100,
|
||||
num_ds_read_inst_b - (num_dsread_b_mfma - 1) * ds_read_b_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
}
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
using ADramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
|
||||
using BDramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"A/B Dram block window should have the same data type as appropriate "
|
||||
"([A|B]DataType) defined in Problem definition!");
|
||||
|
||||
constexpr bool is_a_col_major =
|
||||
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
static_assert(is_a_col_major
|
||||
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"A block window has incorrect lengths for defined ALayout!");
|
||||
static_assert(is_b_row_major
|
||||
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
|
||||
// ------------------------------------------------------------------------------------
|
||||
// Definitions of all needed tiles
|
||||
|
||||
// A/B tiles in LDS
|
||||
auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
|
||||
|
||||
// Tile distribution for load from lds
|
||||
constexpr auto a_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
|
||||
constexpr auto b_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
|
||||
|
||||
// A DRAM tile window for load
|
||||
// A LDS tile window for store
|
||||
// A LDS tile for block GEMM
|
||||
auto&& [a_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] =
|
||||
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
|
||||
|
||||
// B DRAM tile window for load
|
||||
// B LDS tile window for store
|
||||
// B LDS tile for block GEMM
|
||||
auto&& [b_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
|
||||
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
auto c_block_tile = block_gemm.MakeCBlockTile();
|
||||
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step =
|
||||
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
|
||||
// -----------------------------------------------------------------------------------------
|
||||
// Gemm pipeline start
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each A0,
|
||||
// A1, … AN. The values A0, A1, … AN are read by the same thread.
|
||||
auto elementwise_As_res =
|
||||
load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
|
||||
// Move each A — the enhanced function move_tile_window is executed, which takes a tuple
|
||||
// as input.
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each B0,
|
||||
// B1, … BN. The values B0, B1, … BN are read by the same thread.
|
||||
auto elementwise_Bs_res =
|
||||
load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
|
||||
// Move each B — the enhanced function move_tile_window is executed, which takes a tuple
|
||||
// as input.
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// LDS write 0
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
|
||||
// global read 1
|
||||
|
||||
elementwise_As_res = load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
elementwise_Bs_res = load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
// main body
|
||||
if constexpr(HasHotLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
|
||||
elementwise_As_res =
|
||||
load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
elementwise_Bs_res =
|
||||
load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
HotLoopScheduler();
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
|
||||
i += 1;
|
||||
} while(i < (num_loop - 1));
|
||||
}
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
// Leak last MFMA block to epilogue region, cover the potential lds-shuffle
|
||||
// latency
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
block_sync_lds();
|
||||
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
return c_block_tile;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief This function runs the pipeline by wrapping it with the tail handler.
|
||||
*
|
||||
* @note This is used by the persistent gemm kernel variants that don't determine
|
||||
* hot loop and tail number on the host side, e.g. grouped gemm kernel.
|
||||
*/
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
bool has_hot_loop,
|
||||
TailNumber tail_number,
|
||||
void* p_smem) const
|
||||
{
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
constexpr bool hot_loop = hot_loop_.value;
|
||||
constexpr auto tail_num = tail_num_.value;
|
||||
constexpr auto PassThrough = [](auto& e, const auto& x) { e = x; };
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
|
||||
a_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
b_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief This function runs the pipeline using compile-time known hot loop and tail number.
|
||||
* @param num_loop The number of loop iterations. This is determined at runtime due to e.g.
|
||||
* SplitK.
|
||||
* @note This is used by the kernel variants that are able to determine
|
||||
* hot loop and tail number on the host side, e.g. non-persistent gemm kernel.
|
||||
*/
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
return operator()(a_dram_block_window_tmp,
|
||||
b_dram_block_window_tmp,
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_number,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
a_element_func,
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Quant operator(), single input: This function runs the pipeline by wrapping it with
|
||||
* the tail handler.
|
||||
*
|
||||
* @note This is used by the persistent gemm kernel variants that don't determine
|
||||
* hot loop and tail number on the host side, e.g. grouped gemm kernel.
|
||||
*/
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
bool has_hot_loop,
|
||||
TailNumber tail_number,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_number,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Quant operator(), single input: This function runs the pipeline using compile-time
|
||||
* known hot loop and tail number.
|
||||
* @param num_loop The number of loop iterations. This is determined at runtime due to e.g.
|
||||
* SplitK.
|
||||
* @note This is used by the kernel variants that are able to determine
|
||||
* hot loop and tail number on the host side, e.g. non-persistent gemm kernel.
|
||||
*/
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,823 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
template <typename Problem>
|
||||
struct BaseGemmPipelineAgBgCrCompV4
|
||||
{
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
|
||||
static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
constexpr index_t HotLoopGlobalReads = 2;
|
||||
return num_loop >= (HotLoopGlobalReads + PrefetchStages);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
if(num_loop == 1)
|
||||
{
|
||||
return TailNumber::One;
|
||||
}
|
||||
if(num_loop % PrefetchStages == 1)
|
||||
{
|
||||
return TailNumber::Three;
|
||||
}
|
||||
else
|
||||
{
|
||||
return TailNumber::Two;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
|
||||
{
|
||||
// Use amd_wave_read_first_lane to avoid higher resource usage.
|
||||
// It forces to store these values in SGPR.
|
||||
// Compiler cannot deduce if one path is used for all threads
|
||||
const bool has_hot_loop_first_lane = amd_wave_read_first_lane(has_hot_loop);
|
||||
const TailNumber tail_number_first_lane = amd_wave_read_first_lane(tail_number);
|
||||
// Handle all the valid cases.
|
||||
if(has_hot_loop_first_lane)
|
||||
{
|
||||
if(tail_number_first_lane == TailNumber::Three)
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Three>{});
|
||||
}
|
||||
else if(tail_number_first_lane == TailNumber::Two)
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Two>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(tail_number_first_lane == TailNumber::Three)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Three>{});
|
||||
}
|
||||
else if(tail_number_first_lane == TailNumber::Two)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Two>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return (run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::One>{}));
|
||||
}
|
||||
}
|
||||
// If execution reaches here, it's an invalid tail_number because it wasn't handled above.
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
__builtin_unreachable();
|
||||
#else
|
||||
throw std::logic_error("Invalid TailNumber: Only TailNumber::Full and smaller than "
|
||||
"PrefetchStages are supported.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* @brief Compute optimized pipeline version 4
|
||||
*
|
||||
* This version introduces a dual LDS window mechanism using a ping-pong buffer approach
|
||||
* for more efficient data handling from global memory. Unlike compute version 3, this method
|
||||
* allows one LDS to fetch data from global memory while the other LDS executes warps for MFMA
|
||||
* matrix multiplication. This dual operation helps in keeping the Warp unit continuously busy,
|
||||
* thereby significantly reducing memory load times and enhancing overall performance.
|
||||
*
|
||||
* @note This version shows improved performance over Compute Version 3 with the same block tile.
|
||||
* It is particularly more efficient for large matrices where M, N, and K are greater than 8K,
|
||||
* even when Compute Version 3's block size is twice that of Compute Version 4.
|
||||
*/
|
||||
template <typename Problem, typename Policy = GemmPipelineAgBgCrCompV4DefaultPolicy>
|
||||
struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
|
||||
{
|
||||
using Base = BaseGemmPipelineAgBgCrCompV4<Problem>;
|
||||
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
|
||||
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
|
||||
static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
using I2 = number<2>;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr bool Async = false;
|
||||
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
|
||||
}
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeB()
|
||||
{
|
||||
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
|
||||
}
|
||||
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
|
||||
|
||||
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
|
||||
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
static constexpr index_t Preshuffle = Problem::Preshuffle;
|
||||
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
|
||||
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
|
||||
|
||||
static_assert(DoubleSmemBuffer == true, "pipeline requires double smem buffer");
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
|
||||
{
|
||||
// clang-format off
|
||||
return "COMPUTE_V4";
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "pipeline_AgBgCrCompV4",
|
||||
concat('x', MPerBlock, NPerBlock, KPerBlock, BlockSize),
|
||||
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
|
||||
concat('x', kPadM, kPadN, kPadK));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
constexpr index_t smem_size = Policy::template GetSmemSize<Problem>();
|
||||
return 2 * smem_size;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
|
||||
{
|
||||
return Policy::template IsTransposeC<Problem>();
|
||||
}
|
||||
|
||||
template <GemmPipelineScheduler Scheduler>
|
||||
struct PipelineImpl : public PipelineImplBase
|
||||
{
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
|
||||
{
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0{});
|
||||
constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1{});
|
||||
constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2{});
|
||||
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0{});
|
||||
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
|
||||
|
||||
constexpr index_t A_LDS_Read_Width = KPerXDL;
|
||||
constexpr index_t B_LDS_Read_Width = KPerXDL;
|
||||
|
||||
constexpr index_t A_Buffer_Load_Inst_Num =
|
||||
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
|
||||
constexpr index_t B_Buffer_Load_Inst_Num =
|
||||
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
|
||||
|
||||
constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
|
||||
constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);
|
||||
|
||||
constexpr index_t A_LDS_Read_Inst_Num =
|
||||
WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
|
||||
constexpr index_t B_LDS_Read_Inst_Num =
|
||||
WaveNumM * NPerBlock * KPerBlock / (BlockSize * KPerXDL);
|
||||
|
||||
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
|
||||
(BlockSize / WaveSize) /
|
||||
(MPerXDL * NPerXDL * KPerXDL);
|
||||
|
||||
constexpr auto num_ds_read_inst_a =
|
||||
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num
|
||||
: A_LDS_Read_Inst_Num / 2;
|
||||
constexpr auto num_ds_read_inst_b =
|
||||
B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num
|
||||
: B_LDS_Read_Inst_Num / 2;
|
||||
|
||||
constexpr auto num_ds_read_inst = num_ds_read_inst_a + num_ds_read_inst_b;
|
||||
constexpr auto num_ds_write_inst = A_LDS_Write_Inst_Num + B_LDS_Write_Inst_Num;
|
||||
constexpr auto num_buffer_load_inst = A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num;
|
||||
constexpr auto num_issue = num_buffer_load_inst;
|
||||
|
||||
static_for<0, num_buffer_load_inst, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100, num_ds_read_inst / num_issue, 0); // DS read : 2
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA: 1
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x200, num_ds_write_inst / num_issue, 0); // DS write : 1
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read :1
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x008, C_MFMA_Inst_Num / num_issue - 3, 0); // MFMA : 5
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* __restrict__ p_smem) const
|
||||
{
|
||||
using ADramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
|
||||
using BDramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"Data Type conflict on A and B matrix input data type.");
|
||||
|
||||
constexpr bool is_a_col_major =
|
||||
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
static_assert(is_a_col_major
|
||||
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"A block window has incorrect lengths for defined ALayout!");
|
||||
static_assert(is_b_row_major
|
||||
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step =
|
||||
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
|
||||
// global prefetch 0
|
||||
// global read 0
|
||||
|
||||
////////////// LDS desc, window & register /////////////////
|
||||
constexpr index_t smem_size = Policy::template GetSmemSize<Problem>();
|
||||
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem);
|
||||
auto&& [a_lds_block1, b_lds_block1] =
|
||||
Base::GetABLdsTensorViews(static_cast<char*>(p_smem) + smem_size);
|
||||
|
||||
constexpr auto a_lds_shape = []() {
|
||||
if constexpr(is_a_load_tr_v())
|
||||
return make_tuple(number<KPerBlock>{}, number<MPerBlock>{});
|
||||
else
|
||||
return make_tuple(number<MPerBlock>{}, number<KPerBlock>{});
|
||||
}();
|
||||
auto a_copy_lds_window0 = make_tile_window(a_lds_block0, a_lds_shape, {0, 0});
|
||||
auto a_copy_lds_window1 = make_tile_window(a_lds_block1, a_lds_shape, {0, 0});
|
||||
|
||||
constexpr auto b_lds_shape = []() {
|
||||
if constexpr(is_b_load_tr_v())
|
||||
return make_tuple(number<KPerBlock>{}, number<NPerBlock>{});
|
||||
else
|
||||
return make_tuple(number<NPerBlock>{}, number<KPerBlock>{});
|
||||
}();
|
||||
auto b_copy_lds_window0 = make_tile_window(b_lds_block0, b_lds_shape, {0, 0});
|
||||
auto b_copy_lds_window1 = make_tile_window(b_lds_block1, b_lds_shape, {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
auto c_block_tile = block_gemm.MakeCBlockTile();
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// Generating a tuple with tile_windows for values A0, A1, ... AN
|
||||
auto a_tile_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(
|
||||
a_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
|
||||
a_dram_block_window_tmp[number<idx>{}].get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
},
|
||||
number<AsLayout::size()>{});
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each A0,
|
||||
// A1, … AN. The values A0, A1, … AN are read by the same thread.
|
||||
auto elementwise_As_res = load_tile_with_elementwise(a_tile_windows, a_element_func);
|
||||
|
||||
// Move each A — the enhanced function move_tile_window is executed, which takes a tuple
|
||||
// as input.
|
||||
move_tile_window(a_tile_windows, a_dram_tile_window_step);
|
||||
|
||||
// Generating a tuple with tile_windows for values B0, B1, ... BN
|
||||
auto b_tile_windows = generate_tuple(
|
||||
[&](auto idx) {
|
||||
return make_tile_window(
|
||||
b_dram_block_window_tmp[number<idx>{}].get_bottom_tensor_view(),
|
||||
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
|
||||
b_dram_block_window_tmp[number<idx>{}].get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
},
|
||||
number<AsLayout::size()>{});
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each B0,
|
||||
// B1, … BN. The values B0, B1, … BN are read by the same thread.
|
||||
auto elementwise_Bs_res = load_tile_with_elementwise(b_tile_windows, b_element_func);
|
||||
|
||||
// Move each B — the enhanced function move_tile_window is executed, which takes a tuple
|
||||
// as input.
|
||||
move_tile_window(b_tile_windows, b_dram_tile_window_step);
|
||||
|
||||
// LDS write 0
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window0, elementwise_As_res);
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window0, elementwise_Bs_res);
|
||||
}
|
||||
|
||||
// global read 1
|
||||
|
||||
elementwise_As_res = load_tile_with_elementwise(a_tile_windows, a_element_func);
|
||||
move_tile_window(a_tile_windows, a_dram_tile_window_step);
|
||||
|
||||
elementwise_Bs_res = load_tile_with_elementwise(b_tile_windows, b_element_func);
|
||||
move_tile_window(b_tile_windows, b_dram_tile_window_step);
|
||||
block_sync_lds();
|
||||
|
||||
constexpr auto ALdsTileDistr =
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
|
||||
constexpr auto BLdsTileDistr =
|
||||
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
|
||||
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
|
||||
ALdsTile a_block_tile0, a_block_tile1;
|
||||
BLdsTile b_block_tile0, b_block_tile1;
|
||||
|
||||
constexpr auto a_lds_input_tile_distr = [&]() {
|
||||
if constexpr(is_a_load_tr_v())
|
||||
return make_static_tile_distribution(
|
||||
typename InputTileDistributionTraits<
|
||||
decltype(BlockGemm::MakeABlockDistributionEncode()),
|
||||
typename Problem::ADataType>::TransposedDstrEncode{});
|
||||
else
|
||||
return ALdsTileDistr;
|
||||
}();
|
||||
constexpr auto b_lds_input_tile_distr = [&]() {
|
||||
if constexpr(is_b_load_tr_v())
|
||||
return make_static_tile_distribution(
|
||||
typename InputTileDistributionTraits<
|
||||
decltype(BlockGemm::MakeBBlockDistributionEncode()),
|
||||
typename Problem::BDataType>::TransposedDstrEncode{});
|
||||
else
|
||||
return BLdsTileDistr;
|
||||
}();
|
||||
auto a_lds_ld_window0 =
|
||||
make_tile_window(a_lds_block0, a_lds_shape, {0, 0}, a_lds_input_tile_distr);
|
||||
auto a_lds_ld_window1 =
|
||||
make_tile_window(a_lds_block1, a_lds_shape, {0, 0}, a_lds_input_tile_distr);
|
||||
auto b_lds_ld_window0 =
|
||||
make_tile_window(b_lds_block0, b_lds_shape, {0, 0}, b_lds_input_tile_distr);
|
||||
auto b_lds_ld_window1 =
|
||||
make_tile_window(b_lds_block1, b_lds_shape, {0, 0}, b_lds_input_tile_distr);
|
||||
|
||||
static_assert(!is_tile_window_linear_v<decltype(a_lds_ld_window0)> &&
|
||||
!is_tile_window_linear_v<decltype(a_lds_ld_window1)> &&
|
||||
!is_tile_window_linear_v<decltype(b_lds_ld_window0)> &&
|
||||
!is_tile_window_linear_v<decltype(b_lds_ld_window1)>,
|
||||
"LDS windows must not be linear");
|
||||
|
||||
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
|
||||
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window1, elementwise_As_res);
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window1, elementwise_Bs_res);
|
||||
}
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each A0,
|
||||
// A1, … AN. The values A0, A1, … AN are read by the same thread.
|
||||
elementwise_As_res = load_tile_with_elementwise(a_tile_windows, a_element_func);
|
||||
move_tile_window(a_tile_windows, a_dram_tile_window_step);
|
||||
|
||||
elementwise_Bs_res = load_tile_with_elementwise(b_tile_windows, b_element_func);
|
||||
move_tile_window(b_tile_windows, b_dram_tile_window_step);
|
||||
|
||||
if constexpr(HasHotLoop)
|
||||
{
|
||||
// minus 2 because we have ping-pong double buffer.
|
||||
index_t iCounter = amd_wave_read_first_lane(num_loop - 2);
|
||||
do
|
||||
{
|
||||
// ping
|
||||
{
|
||||
block_sync_lds();
|
||||
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
|
||||
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window0, elementwise_As_res);
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window0, elementwise_Bs_res);
|
||||
}
|
||||
|
||||
elementwise_As_res =
|
||||
load_tile_with_elementwise(a_tile_windows, a_element_func);
|
||||
move_tile_window(a_tile_windows, a_dram_tile_window_step);
|
||||
|
||||
elementwise_Bs_res =
|
||||
load_tile_with_elementwise(b_tile_windows, b_element_func);
|
||||
move_tile_window(b_tile_windows, b_dram_tile_window_step);
|
||||
// gemm
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
HotLoopScheduler();
|
||||
}
|
||||
// pong
|
||||
{
|
||||
block_sync_lds();
|
||||
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
|
||||
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window1, elementwise_As_res);
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window1, elementwise_Bs_res);
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
elementwise_As_res =
|
||||
load_tile_with_elementwise(a_tile_windows, a_element_func);
|
||||
move_tile_window(a_tile_windows, a_dram_tile_window_step);
|
||||
|
||||
elementwise_Bs_res =
|
||||
load_tile_with_elementwise(b_tile_windows, b_element_func);
|
||||
move_tile_window(b_tile_windows, b_dram_tile_window_step);
|
||||
|
||||
// gemm
|
||||
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
|
||||
HotLoopScheduler();
|
||||
}
|
||||
iCounter -= 2;
|
||||
} while(iCounter > 1);
|
||||
}
|
||||
|
||||
// tail 3
|
||||
if(TailNum == TailNumber::Three)
|
||||
{
|
||||
// 3
|
||||
{
|
||||
block_sync_lds();
|
||||
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window0, elementwise_As_res);
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window0, elementwise_Bs_res);
|
||||
}
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
}
|
||||
// 2
|
||||
{
|
||||
block_sync_lds();
|
||||
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0, is_b_load_tr_v);
|
||||
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
|
||||
}
|
||||
// 1
|
||||
{
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
}
|
||||
else if(TailNum == TailNumber::Two)
|
||||
{
|
||||
// 2
|
||||
{
|
||||
block_sync_lds();
|
||||
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1, is_b_load_tr_v);
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
static_for<0, 8, 1>{}([&](auto) {
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 8, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
// 1
|
||||
{
|
||||
block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
}
|
||||
else if(TailNum == TailNumber::One)
|
||||
{
|
||||
block_sync_lds();
|
||||
block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
return c_block_tile;
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const index_t num_loop,
|
||||
void* __restrict__ p_smem) const
|
||||
{
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp,
|
||||
[](auto& e, const ADataType& a) { e = a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](auto& e, const BDataType& b) { e = b; },
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
bool has_hot_loop,
|
||||
TailNumber tail_number,
|
||||
void* __restrict__ p_smem) const
|
||||
{
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
constexpr bool hot_loop = hot_loop_.value;
|
||||
constexpr auto tail_num = tail_num_.value;
|
||||
constexpr auto PassThrough = [](auto& e, const auto& x) { e = x; };
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop, tail_num>(
|
||||
a_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
b_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
a_element_func,
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const index_t num_loop,
|
||||
void* __restrict__ p_smem) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
bool has_hot_loop,
|
||||
TailNumber tail_number,
|
||||
void* __restrict__ p_smem) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
num_loop,
|
||||
has_hot_loop,
|
||||
tail_number,
|
||||
p_smem);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,56 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
// Default policy for GemmPipelineAGmemBGmemCregComputeV4, except the block gemm method, it shares
|
||||
// the same vector size implementation, SmemSize, Global memory tile distiribution as the
|
||||
// UniversalGemm Pipeline Policy.
|
||||
// Default policy class should not be templated, put template on
|
||||
// member functions instead.
|
||||
struct GemmPipelineAgBgCrCompV4DefaultPolicy
|
||||
: public UniversalGemmBasePolicy<GemmPipelineAgBgCrCompV4DefaultPolicy>
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
|
||||
constexpr index_t vector_size =
|
||||
DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType);
|
||||
constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size();
|
||||
constexpr auto wg_attr_num_access =
|
||||
!(is_a_load_tr<Problem> || is_b_load_tr<Problem>) ? WGAttrNumAccessEnum::Single
|
||||
: vector_size == thread_elements ? WGAttrNumAccessEnum::Single
|
||||
: vector_size * 2 == thread_elements ? WGAttrNumAccessEnum::Double
|
||||
: vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad
|
||||
: WGAttrNumAccessEnum::Invalid;
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType, // AccDataType
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC,
|
||||
false,
|
||||
false,
|
||||
wg_attr_num_access>;
|
||||
|
||||
using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
|
||||
return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,488 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v5_default_policy.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed Tensor: register
|
||||
|
||||
template <typename Problem>
|
||||
struct BaseGemmPipelineAgBgCrCompV5
|
||||
{
|
||||
static constexpr index_t PrefetchStages = 1;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t)
|
||||
{
|
||||
return TailNumber::Empty;
|
||||
}
|
||||
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber)
|
||||
{
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem, typename Policy = GemmPipelineAgBgCrCompV5DefaultPolicy>
|
||||
struct GemmPipelineAgBgCrCompV5 : public BaseGemmPipelineAgBgCrCompV5<Problem>
|
||||
{
|
||||
using Base = BaseGemmPipelineAgBgCrCompV5<Problem>;
|
||||
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
|
||||
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
using I2 = number<2>;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr bool Async = false;
|
||||
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
|
||||
}
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeB()
|
||||
{
|
||||
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
|
||||
}
|
||||
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr index_t Preshuffle = Problem::Preshuffle;
|
||||
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
static constexpr index_t NumWarps = BlockGemmShape::NumWarps;
|
||||
static constexpr index_t KTileSize = BlockGemmShape::WarpTile::at(I2{});
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
|
||||
{
|
||||
// clang-format off
|
||||
return "COMPUTE_V5";
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "pipeline_AgBgCrCompV5", BlockSize,
|
||||
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
|
||||
concat('x', kPadM, kPadN, kPadK));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
|
||||
{
|
||||
return Policy::template IsTransposeC<Problem>();
|
||||
}
|
||||
|
||||
template <GemmPipelineScheduler Scheduler>
|
||||
struct PipelineImpl : public PipelineImplBase
|
||||
{
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
|
||||
{
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename AsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* __restrict__ p_smem_0) const
|
||||
{
|
||||
using ADramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
|
||||
using BDramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"Data Type conflict on A and B matrix input data type.");
|
||||
|
||||
static_assert(
|
||||
KPerBlock % ((NumWarps / 2) * KTileSize) == 0,
|
||||
"Ping Pong Warps, TileSize and Block Size for K dimensions does not match.");
|
||||
|
||||
constexpr bool is_a_col_major =
|
||||
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
static_assert(is_a_col_major
|
||||
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"A block window has incorrect lengths for defined ALayout!");
|
||||
static_assert(is_b_row_major
|
||||
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
|
||||
index_t warp_id = get_warp_id();
|
||||
index_t operation_id =
|
||||
amd_wave_read_first_lane(get_warp_id()); // 0 - Memory read, 1 - block-gemm
|
||||
|
||||
auto a_offset = (warp_id == 0) ? make_array(0, 0) : make_array(0, KPerBlock);
|
||||
auto b_offset = (warp_id == 0) ? make_array(0, 0) : make_array(0, KPerBlock);
|
||||
|
||||
auto tensor_views =
|
||||
Base::GetABLdsTensorViews(static_cast<void*>(static_cast<char*>(p_smem_0)));
|
||||
auto& a_lds_block = tensor_views.get(number<0>{});
|
||||
auto& b_lds_block = tensor_views.get(number<1>{});
|
||||
|
||||
constexpr auto a_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
|
||||
constexpr auto b_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
|
||||
|
||||
auto a_windows = Base::GetAWindows(
|
||||
a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr, a_offset);
|
||||
auto& a_copy_dram_window = a_windows.get(number<0>{});
|
||||
auto& a_copy_lds_window = a_windows.get(number<1>{});
|
||||
auto& a_lds_window = a_windows.get(number<2>{});
|
||||
|
||||
auto b_windows = Base::GetBWindows(
|
||||
b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr, b_offset);
|
||||
auto& b_copy_dram_window = b_windows.get(number<0>{});
|
||||
auto& b_copy_lds_window = b_windows.get(number<1>{});
|
||||
auto& b_lds_window = b_windows.get(number<2>{});
|
||||
|
||||
// DRAM window steps.
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step =
|
||||
is_a_col_major ? make_array(KPerBlock * NumWarps, 0)
|
||||
: make_array(0, KPerBlock * NumWarps);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(KPerBlock * NumWarps, 0)
|
||||
: make_array(0, KPerBlock * NumWarps);
|
||||
|
||||
constexpr auto AGemmTileDistr = decltype(make_static_tile_distribution(
|
||||
BlockGemm::MakeABlockDistributionEncode())){};
|
||||
constexpr auto BGemmTileDistr = decltype(make_static_tile_distribution(
|
||||
BlockGemm::MakeBBlockDistributionEncode())){};
|
||||
|
||||
using AGemmTile = decltype(make_static_distributed_tensor<ADataType>(AGemmTileDistr));
|
||||
using BGemmTile = decltype(make_static_distributed_tensor<BDataType>(BGemmTileDistr));
|
||||
AGemmTile a_tile_0, a_tile_1;
|
||||
BGemmTile b_tile_0, b_tile_1;
|
||||
|
||||
// Register tile for A and B.
|
||||
using ABlockTileDistr =
|
||||
decltype(a_copy_dram_window[number<0>{}].get_tile_distribution());
|
||||
using BBlockTileDistr =
|
||||
decltype(b_copy_dram_window[number<0>{}].get_tile_distribution());
|
||||
using ABlockTile =
|
||||
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
using BBlockTile =
|
||||
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
|
||||
ABlockTile elementwise_As_res;
|
||||
BBlockTile elementwise_Bs_res;
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
auto c_block_tile_0 = block_gemm.MakeCBlockTile();
|
||||
auto c_block_tile_1 = block_gemm.MakeCBlockTile();
|
||||
|
||||
CDataType* __restrict__ p_c_lds = static_cast<CDataType*>(p_smem_0);
|
||||
auto c_lds_block_0 =
|
||||
make_naive_tensor_view<address_space_enum::lds>(p_c_lds,
|
||||
make_tuple(MPerBlock, NPerBlock),
|
||||
make_tuple(NPerBlock, 1),
|
||||
number<BlockGemm::Traits::KPack>{},
|
||||
number<1>{});
|
||||
auto c_window_0 = make_tile_window(c_lds_block_0,
|
||||
make_tuple(number<MPerBlock>{}, number<NPerBlock>{}),
|
||||
{0, 0},
|
||||
c_block_tile_1.get_tile_distribution());
|
||||
|
||||
// initialize C
|
||||
if(warp_id == 0)
|
||||
{
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile_1);
|
||||
}
|
||||
|
||||
// define ping, pong steps here as lambda functions.
|
||||
auto MemoryOpsStep = [&](auto idx) {
|
||||
// Memory read half here.
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each
|
||||
// A0, A1, … AN. The values A0, A1, … AN are read by the same thread.
|
||||
elementwise_As_res = load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
|
||||
// Move each A — the enhanced function move_tile_window is executed, which takes a
|
||||
// tuple as input.
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each
|
||||
// B0, B1, … BN. The values B0, B1, … BN are read by the same thread.
|
||||
elementwise_Bs_res = load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
|
||||
// Move each B — the enhanced function move_tile_window is executed, which takes a
|
||||
// tuple as input.
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
if constexpr(is_a_col_major)
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
|
||||
if constexpr(is_b_row_major)
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
|
||||
if(idx == 0)
|
||||
{
|
||||
Base::LocalPrefetch(a_tile_0, a_lds_window);
|
||||
Base::LocalPrefetch(b_tile_0, b_lds_window);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefetch(a_tile_1, a_lds_window);
|
||||
Base::LocalPrefetch(b_tile_1, b_lds_window);
|
||||
}
|
||||
};
|
||||
|
||||
auto ComputeStep = [&](auto idx) {
|
||||
if(idx == 0)
|
||||
{
|
||||
block_gemm(c_block_tile_0, a_tile_0, b_tile_0);
|
||||
}
|
||||
else
|
||||
{
|
||||
block_gemm(c_block_tile_1, a_tile_1, b_tile_1);
|
||||
}
|
||||
};
|
||||
|
||||
if(operation_id == 0)
|
||||
{
|
||||
MemoryOpsStep(warp_id);
|
||||
}
|
||||
|
||||
index_t num_compute_steps = amd_wave_read_first_lane(num_loop);
|
||||
while(num_compute_steps > 1)
|
||||
{
|
||||
block_sync_lds();
|
||||
operation_id = (operation_id + 1) % NumWaveGroups;
|
||||
|
||||
if(operation_id == 0)
|
||||
{
|
||||
MemoryOpsStep(warp_id);
|
||||
}
|
||||
else
|
||||
{
|
||||
ComputeStep(warp_id);
|
||||
}
|
||||
num_compute_steps -= 1;
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
if(operation_id == 0)
|
||||
{
|
||||
ComputeStep(warp_id);
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
if(warp_id == 1)
|
||||
{
|
||||
store_tile(c_window_0, c_block_tile_1);
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
if(warp_id == 0)
|
||||
{
|
||||
load_tile(c_block_tile_1, c_window_0);
|
||||
|
||||
constexpr auto s_spans = decltype(c_block_tile_0)::get_distributed_spans();
|
||||
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
|
||||
auto idx2 = make_tuple(idx0, idx1);
|
||||
c_block_tile_0(idx2) += c_block_tile_1(idx2);
|
||||
});
|
||||
});
|
||||
}
|
||||
return c_block_tile_0;
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem_0) const
|
||||
{
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem_0);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const index_t num_loop,
|
||||
void* __restrict__ p_smem_0) const
|
||||
{
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp,
|
||||
[](auto& e, const ADataType& a) { e = a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](auto& e, const BDataType& b) { e = b; },
|
||||
num_loop,
|
||||
p_smem_0);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem_0) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
a_element_func,
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem_0);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const index_t num_loop,
|
||||
void* __restrict__ p_smem_0) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
num_loop,
|
||||
p_smem_0);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,64 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
// Default policy for GemmPipelineAGmemBGmemCregComputeV5, except the block gemm method, it shares
|
||||
// the same vector size implementation, SmemSize, Global memory tile distiribution as the
|
||||
// UniversalGemm Pipeline Policy.
|
||||
// Default policy class should not be templated, put template on
|
||||
// member functions instead.
|
||||
struct GemmPipelineAgBgCrCompV5DefaultPolicy
|
||||
: public UniversalGemmBasePolicy<GemmPipelineAgBgCrCompV5DefaultPolicy>
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
// using AccDataType = float;
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType, // AccDataType
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
|
||||
using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
|
||||
return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr index_t GetSmemSizeC()
|
||||
{
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
|
||||
return integer_least_multiple(sizeof(typename Problem::CDataType) * MPerBlock * NPerBlock,
|
||||
16);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
|
||||
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
|
||||
constexpr index_t smem_size_c = GetSmemSizeC<Problem>();
|
||||
|
||||
return smem_size_a + smem_size_b >= smem_size_c ? (smem_size_a + smem_size_b)
|
||||
: (smem_size_c);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,821 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v6_default_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
template <typename Problem>
|
||||
struct BaseGemmPipelineAgBgCrCompV6
|
||||
{
|
||||
static constexpr index_t PrefetchStages = 3;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 2;
|
||||
static constexpr index_t HotloopUnroll = 2;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
if(num_loop % HotloopUnroll == 1)
|
||||
{
|
||||
return TailNumber::Odd;
|
||||
}
|
||||
else
|
||||
{
|
||||
return TailNumber::Even;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
|
||||
{
|
||||
// Use amd_wave_read_first_lane to avoid higher resource usage.
|
||||
// It forces to store these values in SGPR.
|
||||
// Compiler cannot deduce if one path is used for all threads
|
||||
const bool has_hot_loop_first_lane = amd_wave_read_first_lane(has_hot_loop);
|
||||
const TailNumber tail_number_first_lane = amd_wave_read_first_lane(tail_number);
|
||||
// Handle all the valid cases.
|
||||
if(has_hot_loop_first_lane)
|
||||
{
|
||||
if(tail_number_first_lane == TailNumber::Odd)
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_number_first_lane == TailNumber::Even)
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Even>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(tail_number_first_lane == TailNumber::Odd)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
}
|
||||
else if(tail_number_first_lane == TailNumber::Even)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Even>{});
|
||||
}
|
||||
}
|
||||
// If execution reaches here, it's an invalid tail_number because it wasn't handled above.
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
__builtin_unreachable();
|
||||
#else
|
||||
throw std::logic_error("Invalid TailNumber: Only TailNumber::Odd and TailNumber::Even are "
|
||||
"supported in this pipeline context.");
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
// Compute optimized pipeline
|
||||
// GlobalPrefetchStages: 3
|
||||
// LocalPreFillStages: 1
|
||||
// LocalPreFetchStages: 1
|
||||
// LocalSharedMemoryBuffer: 2
|
||||
template <typename Problem, typename Policy = GemmPipelineAgBgCrCompV6DefaultPolicy>
|
||||
struct GemmPipelineAgBgCrCompV6 : public BaseGemmPipelineAgBgCrCompV6<Problem>
|
||||
{
|
||||
using Base = BaseGemmPipelineAgBgCrCompV6<Problem>;
|
||||
using BasePImpl = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
|
||||
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
|
||||
static_assert(!std::is_same_v<BDataType, pk_int4_t>, "Not implemented");
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr bool Async = false;
|
||||
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
|
||||
}
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeB()
|
||||
{
|
||||
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
|
||||
}
|
||||
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
|
||||
|
||||
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
|
||||
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
|
||||
|
||||
static constexpr index_t KRepeat = BlockGemm::WarpGemm::kKPerThread / GetSmemPackA();
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
static constexpr index_t Preshuffle = Problem::Preshuffle;
|
||||
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
static constexpr auto is_a_load_tr_v = bool_constant<BasePImpl::is_a_load_tr>{};
|
||||
static constexpr auto is_b_load_tr_v = bool_constant<BasePImpl::is_b_load_tr>{};
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
|
||||
{
|
||||
// clang-format off
|
||||
return "COMPUTE_V6";
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "pipeline_AgBgCrCompV6", BlockSize,
|
||||
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
|
||||
concat('x', kPadM, kPadN, kPadK),
|
||||
concat('_', KRepeat),
|
||||
concat('_', DoubleSmemBuffer),
|
||||
concat('_', Preshuffle));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
|
||||
{
|
||||
return Policy::template IsTransposeC<Problem>();
|
||||
}
|
||||
|
||||
template <GemmPipelineScheduler Scheduler>
|
||||
struct PipelineImpl : public BasePImpl
|
||||
{
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public BasePImpl
|
||||
{
|
||||
CK_TILE_DEVICE static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
constexpr index_t MPerXDL = BlockGemmShape::WarpTile::at(I0);
|
||||
constexpr index_t NPerXDL = BlockGemmShape::WarpTile::at(I1);
|
||||
constexpr index_t KPerXDL = BlockGemmShape::WarpTile::at(I2);
|
||||
|
||||
constexpr index_t WaveSize = 64;
|
||||
constexpr index_t WaveNumM = BlockGemmShape::BlockWarps::at(I0);
|
||||
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1);
|
||||
|
||||
constexpr index_t A_LDS_Read_Width = KPerXDL;
|
||||
constexpr index_t B_LDS_Read_Width = KPerXDL;
|
||||
|
||||
constexpr index_t A_Buffer_Load_Inst_Num =
|
||||
MPerBlock * KPerBlock / (BlockSize * GetVectorSizeA());
|
||||
constexpr index_t B_Buffer_Load_Inst_Num =
|
||||
NPerBlock * KPerBlock / (BlockSize * GetVectorSizeB());
|
||||
|
||||
constexpr index_t A_LDS_Write_Inst_Num = MPerBlock * KPerBlock / (BlockSize * KPerXDL);
|
||||
constexpr index_t B_LDS_Write_Inst_Num = NPerBlock * KPerBlock / (BlockSize * KPerXDL);
|
||||
|
||||
constexpr index_t A_LDS_Read_Inst_Num =
|
||||
WaveNumN * MPerBlock * KPerBlock / (BlockSize * KPerXDL);
|
||||
constexpr index_t B_LDS_Read_Inst_Num =
|
||||
WaveNumM * NPerBlock * KPerBlock / (BlockSize * KPerXDL);
|
||||
|
||||
constexpr index_t C_MFMA_Inst_Num = MPerBlock * NPerBlock * KPerBlock /
|
||||
(BlockSize / WaveSize) /
|
||||
(MPerXDL * NPerXDL * KPerXDL);
|
||||
|
||||
constexpr auto num_ds_read_inst_a =
|
||||
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? A_LDS_Read_Inst_Num
|
||||
: A_LDS_Read_Inst_Num / 2;
|
||||
constexpr auto num_ds_read_inst_b =
|
||||
B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? B_LDS_Read_Inst_Num
|
||||
: B_LDS_Read_Inst_Num / 2;
|
||||
|
||||
constexpr auto mfma_cycle = NPerXDL == 16 ? 16 : 32;
|
||||
|
||||
constexpr auto ds_read_a_issue_cycle =
|
||||
A_LDS_Read_Width * sizeof(ADataType) / APackedSize == 16 ? 8 : 4;
|
||||
constexpr auto ds_read_b_issue_cycle =
|
||||
B_LDS_Read_Width * sizeof(BDataType) / BPackedSize == 16 ? 8 : 4;
|
||||
|
||||
constexpr auto ds_read_a_mfma_rate =
|
||||
(mfma_cycle - 4 + 2 * ds_read_a_issue_cycle - 1) / (2 * ds_read_a_issue_cycle);
|
||||
constexpr auto ds_read_b_mfma_rate =
|
||||
(mfma_cycle - 4 + 2 * ds_read_b_issue_cycle - 1) / (2 * ds_read_b_issue_cycle);
|
||||
|
||||
constexpr auto num_dsread_stage1_a = num_ds_read_inst_a / KRepeat * (KRepeat - 1);
|
||||
constexpr auto num_dsread_stage1_b = num_ds_read_inst_b / KRepeat * (KRepeat - 1);
|
||||
constexpr auto num_dsread_stage3_a = num_ds_read_inst_a / KRepeat;
|
||||
constexpr auto num_dsread_stage3_b = num_ds_read_inst_b / KRepeat;
|
||||
|
||||
constexpr auto num_dsread_stage1_a_mfma =
|
||||
(num_dsread_stage1_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
|
||||
constexpr auto num_dsread_stage1_b_mfma =
|
||||
(num_dsread_stage1_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
|
||||
constexpr auto num_dsread_stage3_a_mfma =
|
||||
(num_dsread_stage3_a + ds_read_a_mfma_rate - 1) / ds_read_a_mfma_rate;
|
||||
constexpr auto num_dsread_stage3_b_mfma =
|
||||
(num_dsread_stage3_b + ds_read_b_mfma_rate - 1) / ds_read_b_mfma_rate;
|
||||
|
||||
constexpr auto num_mfma_stage2 = C_MFMA_Inst_Num -
|
||||
num_ds_read_inst_a / ds_read_a_mfma_rate -
|
||||
num_ds_read_inst_b / ds_read_b_mfma_rate;
|
||||
constexpr auto num_mfma_per_issue =
|
||||
num_mfma_stage2 / (A_Buffer_Load_Inst_Num + B_Buffer_Load_Inst_Num);
|
||||
constexpr auto num_dswrite_per_issue_a = A_LDS_Write_Inst_Num / A_Buffer_Load_Inst_Num;
|
||||
constexpr auto num_dswrite_per_issue_b = B_LDS_Write_Inst_Num / B_Buffer_Load_Inst_Num;
|
||||
|
||||
// stage 1
|
||||
static_for<0, num_dsread_stage1_a_mfma, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
if constexpr((num_dsread_stage1_a - (i + 1) * ds_read_a_mfma_rate) >=
|
||||
ds_read_a_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100,
|
||||
num_dsread_stage1_a - (num_dsread_stage1_a_mfma - 1) * ds_read_a_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, num_dsread_stage1_b_mfma, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
if constexpr((num_dsread_stage1_b - (i + 1) * ds_read_b_mfma_rate) >=
|
||||
ds_read_b_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100,
|
||||
num_dsread_stage1_b - (num_dsread_stage1_b_mfma - 1) * ds_read_b_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
|
||||
// stage 2
|
||||
static_for<0, A_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
static_for<0, num_dswrite_per_issue_a, 1>{}([&](auto idswrite) {
|
||||
ignore = idswrite;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x008, num_mfma_per_issue - num_dswrite_per_issue_a, 0); // MFMA
|
||||
});
|
||||
static_for<0, B_Buffer_Load_Inst_Num, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
static_for<0, num_dswrite_per_issue_b, 1>{}([&](auto idswrite) {
|
||||
ignore = idswrite;
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x008, num_mfma_per_issue - num_dswrite_per_issue_b, 0); // MFMA
|
||||
});
|
||||
|
||||
// stage 3
|
||||
static_for<0, num_dsread_stage3_a_mfma, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
if constexpr((num_dsread_stage3_a - (i + 1) * ds_read_a_mfma_rate) >=
|
||||
ds_read_a_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_a_mfma_rate, 0); // DS read
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100,
|
||||
num_dsread_stage3_a - (num_dsread_stage3_a_mfma - 1) * ds_read_a_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
static_for<0, num_dsread_stage3_b_mfma, 1>{}([&](auto i) {
|
||||
ignore = i;
|
||||
if constexpr((num_dsread_stage3_b - (i + 1) * ds_read_b_mfma_rate) >=
|
||||
ds_read_b_mfma_rate)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, ds_read_b_mfma_rate, 0); // DS read
|
||||
}
|
||||
else
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(
|
||||
0x100,
|
||||
num_dsread_stage3_b - (num_dsread_stage3_b_mfma - 1) * ds_read_b_mfma_rate,
|
||||
0); // DS read
|
||||
}
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
});
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* __restrict__ p_smem) const
|
||||
{
|
||||
// TODO: Add Multi A/B support
|
||||
static_assert(std::tuple_size<remove_cvref_t<AsDramBlockWindowTmp>>::value == 1,
|
||||
"Multi A/B is not yet supported for this pipeline.");
|
||||
static_assert(std::tuple_size<remove_cvref_t<BsDramBlockWindowTmp>>::value == 1,
|
||||
"Multi A/B is not yet supported for this pipeline.");
|
||||
|
||||
using ADramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
|
||||
using BDramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"Data Type conflict on A and B matrix input data type.");
|
||||
|
||||
constexpr bool is_a_col_major =
|
||||
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
static_assert(is_a_col_major
|
||||
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0] &&
|
||||
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1])
|
||||
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0] &&
|
||||
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1]),
|
||||
"A block window has incorrect lengths for defined ALayout!");
|
||||
static_assert(is_b_row_major
|
||||
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0] &&
|
||||
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1])
|
||||
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0] &&
|
||||
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1]),
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
|
||||
////////////// LDS desc, window & register /////////////////
|
||||
using ALdsType =
|
||||
remove_cvref_t<decltype(BasePImpl::GetABLdsTensorViews(p_smem).at(I0))>;
|
||||
using BLdsType =
|
||||
remove_cvref_t<decltype(BasePImpl::GetABLdsTensorViews(p_smem).at(I1))>;
|
||||
auto&& ABLdsTensorViews = BasePImpl::GetABLdsTensorViews(p_smem);
|
||||
ALdsType& a_lds_block = ABLdsTensorViews.at(I0);
|
||||
BLdsType& b_lds_block = ABLdsTensorViews.at(I1);
|
||||
|
||||
// Tile distribution for load from lds
|
||||
constexpr auto a_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
|
||||
constexpr auto b_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
|
||||
|
||||
using acopy_dram_type =
|
||||
remove_cvref_t<decltype(BasePImpl::GetAWindows(a_dram_block_window_tmp,
|
||||
a_lds_block,
|
||||
a_lds_load_tile_distr)
|
||||
.at(I0))>;
|
||||
using bcopy_dram_type =
|
||||
remove_cvref_t<decltype(BasePImpl::GetBWindows(b_dram_block_window_tmp,
|
||||
b_lds_block,
|
||||
b_lds_load_tile_distr)
|
||||
.at(I0))>;
|
||||
|
||||
using a_copy_lds_window_type =
|
||||
remove_cvref_t<decltype(BasePImpl::GetAWindows(a_dram_block_window_tmp,
|
||||
a_lds_block,
|
||||
a_lds_load_tile_distr)
|
||||
.at(I1))>;
|
||||
using b_copy_lds_window_type =
|
||||
remove_cvref_t<decltype(BasePImpl::GetBWindows(b_dram_block_window_tmp,
|
||||
b_lds_block,
|
||||
b_lds_load_tile_distr)
|
||||
.at(I1))>;
|
||||
|
||||
using a_lds_load_tile_distr_type =
|
||||
remove_cvref_t<decltype(BasePImpl::GetAWindows(a_dram_block_window_tmp,
|
||||
a_lds_block,
|
||||
a_lds_load_tile_distr)
|
||||
.at(I2))>;
|
||||
using b_lds_load_tile_distr_type =
|
||||
remove_cvref_t<decltype(BasePImpl::GetBWindows(b_dram_block_window_tmp,
|
||||
b_lds_block,
|
||||
b_lds_load_tile_distr)
|
||||
.at(I2))>;
|
||||
|
||||
auto&& aWindows =
|
||||
BasePImpl::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
|
||||
auto&& bWindows =
|
||||
BasePImpl::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
|
||||
|
||||
// A DRAM tile window for load
|
||||
// A LDS tile window for store
|
||||
// A LDS tile for block GEMM
|
||||
acopy_dram_type& a_copy_dram_window = aWindows.at(I0);
|
||||
a_copy_lds_window_type& a_copy_lds_window = aWindows.at(I1);
|
||||
a_lds_load_tile_distr_type& a_lds_gemm_window = aWindows.at(I2);
|
||||
|
||||
// B DRAM tile window for load
|
||||
// B LDS tile window for store
|
||||
// B LDS tile for block GEMM
|
||||
bcopy_dram_type& b_copy_dram_window = bWindows.at(I0);
|
||||
b_copy_lds_window_type& b_copy_lds_window = bWindows.at(I1);
|
||||
b_lds_load_tile_distr_type& b_lds_gemm_window = bWindows.at(I2);
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
auto c_block_tile = block_gemm.MakeCBlockTile();
|
||||
|
||||
using ABlockTileDistr =
|
||||
decltype(a_copy_dram_window[number<0>{}].get_tile_distribution());
|
||||
using BBlockTileDistr =
|
||||
decltype(b_copy_dram_window[number<0>{}].get_tile_distribution());
|
||||
|
||||
using ABlockTile =
|
||||
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
using BBlockTile =
|
||||
decltype(make_static_distributed_tensor<BDataType>(BBlockTileDistr{}));
|
||||
|
||||
ABlockTile a_block_tile[Base::GlobalBufferNum];
|
||||
BBlockTile b_block_tile[Base::GlobalBufferNum];
|
||||
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step =
|
||||
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
|
||||
|
||||
constexpr auto ALdsTileDistr = decltype(make_static_tile_distribution(
|
||||
BlockGemm::MakeABlockDistributionEncode())){};
|
||||
constexpr auto BLdsTileDistr = decltype(make_static_tile_distribution(
|
||||
BlockGemm::MakeBBlockDistributionEncode())){};
|
||||
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
|
||||
|
||||
ALdsTile a_lds_tile;
|
||||
BLdsTile b_lds_tile;
|
||||
// -----------------------------------------------------------------------------------------
|
||||
// Gemm pipeline start
|
||||
|
||||
// Global prefetch 1
|
||||
a_block_tile[I0] = load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
b_block_tile[I0] = load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// Local prefill 1
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile[I0]);
|
||||
BasePImpl::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
BasePImpl::LocalPrefill(a_copy_lds_window, a_block_tile[I0]);
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile[I0]);
|
||||
BasePImpl::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
BasePImpl::LocalPrefill(b_copy_lds_window, b_block_tile[I0]);
|
||||
}
|
||||
|
||||
// Global prefetch 2
|
||||
a_block_tile[I0] = load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
b_block_tile[I0] = load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// Global prefetch 3
|
||||
a_block_tile[I1] = load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
b_block_tile[I1] = load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// Local prefetch 1
|
||||
BasePImpl::LocalPrefetch(a_lds_tile, a_lds_gemm_window, is_a_load_tr_v);
|
||||
BasePImpl::LocalPrefetch(b_lds_tile, b_lds_gemm_window, is_b_load_tr_v);
|
||||
|
||||
if constexpr(HasHotLoop)
|
||||
{
|
||||
index_t i = 0;
|
||||
do
|
||||
{
|
||||
auto LoopFunc = [&](auto vmem_buf_idx) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
if constexpr(k0 == (KRepeat - 1))
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
// Local prefill 2
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<
|
||||
Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile[vmem_buf_idx]);
|
||||
BasePImpl::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
BasePImpl::LocalPrefill(a_copy_lds_window,
|
||||
a_block_tile[vmem_buf_idx]);
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<
|
||||
Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile[vmem_buf_idx]);
|
||||
BasePImpl::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
BasePImpl::LocalPrefill(b_copy_lds_window,
|
||||
b_block_tile[vmem_buf_idx]);
|
||||
}
|
||||
|
||||
// Global prefetch 4
|
||||
a_block_tile[vmem_buf_idx] =
|
||||
load_tile_with_elementwise(a_copy_dram_window, a_element_func);
|
||||
move_tile_window(a_copy_dram_window, a_dram_tile_window_step);
|
||||
b_block_tile[vmem_buf_idx] =
|
||||
load_tile_with_elementwise(b_copy_dram_window, b_element_func);
|
||||
move_tile_window(b_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
block_sync_lds();
|
||||
}
|
||||
block_gemm(c_block_tile, a_lds_tile, b_lds_tile);
|
||||
|
||||
// Local prefetch 2
|
||||
BasePImpl::LocalPrefetch(a_lds_tile, a_lds_gemm_window, is_a_load_tr_v);
|
||||
BasePImpl::LocalPrefetch(b_lds_tile, b_lds_gemm_window, is_b_load_tr_v);
|
||||
});
|
||||
|
||||
HotLoopScheduler();
|
||||
};
|
||||
|
||||
LoopFunc(I0);
|
||||
LoopFunc(I1);
|
||||
|
||||
i += Base::HotloopUnroll;
|
||||
} while(i < (num_loop - Base::PrefetchStages));
|
||||
}
|
||||
|
||||
auto ReadWriteCompFunc = [&](auto vmem_buf_idx) {
|
||||
static_for<0, KRepeat, 1>{}([&](auto k0) {
|
||||
if constexpr(k0 == (KRepeat - 1))
|
||||
{
|
||||
block_sync_lds();
|
||||
|
||||
// Local prefill 3
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, a_block_tile[vmem_buf_idx]);
|
||||
BasePImpl::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
BasePImpl::LocalPrefill(a_copy_lds_window, a_block_tile[vmem_buf_idx]);
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, b_block_tile[vmem_buf_idx]);
|
||||
BasePImpl::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
BasePImpl::LocalPrefill(b_copy_lds_window, b_block_tile[vmem_buf_idx]);
|
||||
}
|
||||
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
block_gemm(c_block_tile, a_lds_tile, b_lds_tile);
|
||||
|
||||
BasePImpl::LocalPrefetch(a_lds_tile, a_lds_gemm_window, is_a_load_tr_v);
|
||||
BasePImpl::LocalPrefetch(b_lds_tile, b_lds_gemm_window, is_b_load_tr_v);
|
||||
});
|
||||
|
||||
HotLoopScheduler();
|
||||
};
|
||||
|
||||
auto ReadCompFunc = [&]() {
|
||||
static_for<0, KRepeat - 1, 1>{}([&]() {
|
||||
__syncthreads();
|
||||
block_gemm(c_block_tile, a_lds_tile, b_lds_tile);
|
||||
|
||||
// Local prefetch 4
|
||||
BasePImpl::LocalPrefetch(a_lds_tile, a_lds_gemm_window, is_a_load_tr_v);
|
||||
BasePImpl::LocalPrefetch(b_lds_tile, b_lds_gemm_window, is_b_load_tr_v);
|
||||
|
||||
__syncthreads();
|
||||
});
|
||||
|
||||
block_gemm(c_block_tile, a_lds_tile, b_lds_tile);
|
||||
|
||||
HotLoopScheduler();
|
||||
};
|
||||
|
||||
if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
ReadWriteCompFunc(I0);
|
||||
ReadWriteCompFunc(I1);
|
||||
ReadCompFunc();
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
ReadWriteCompFunc(I0);
|
||||
ReadCompFunc();
|
||||
}
|
||||
|
||||
return c_block_tile;
|
||||
}
|
||||
};
|
||||
|
||||
public:
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* __restrict__ p_smem) const
|
||||
{
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const index_t num_loop,
|
||||
void* __restrict__ p_smem) const
|
||||
{
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp,
|
||||
[](auto& e, const ADataType& a) { e = a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](auto& e, const BDataType& b) { e = b; },
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const index_t num_loop,
|
||||
void* __restrict__ p_smem) const
|
||||
{
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
[](auto& e, const ADataType& a) { e = a; },
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
[](auto& e, const BDataType& b) { e = b; },
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* __restrict__ p_smem) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
a_element_func,
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,56 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
// Default policy for GemmPipelineAGmemBGmemCregComputeV6, except the block gemm method, it shares
|
||||
// the same vector size implementation, SmemSize, Global memory tile distiribution as the
|
||||
// UniversalGemm Pipeline Policy.
|
||||
// Default policy class should not be templated, put template on
|
||||
// member functions instead.
|
||||
struct GemmPipelineAgBgCrCompV6DefaultPolicy
|
||||
: public UniversalGemmBasePolicy<GemmPipelineAgBgCrCompV6DefaultPolicy>
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
|
||||
constexpr index_t vector_size =
|
||||
DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType);
|
||||
constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size();
|
||||
constexpr auto wg_attr_num_access =
|
||||
!(is_a_load_tr<Problem> || is_b_load_tr<Problem>) ? WGAttrNumAccessEnum::Single
|
||||
: vector_size == thread_elements ? WGAttrNumAccessEnum::Single
|
||||
: vector_size * 2 == thread_elements ? WGAttrNumAccessEnum::Double
|
||||
: vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad
|
||||
: WGAttrNumAccessEnum::Invalid;
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC,
|
||||
false,
|
||||
false,
|
||||
wg_attr_num_access>;
|
||||
|
||||
using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
|
||||
return BlockGemmARegBRegCRegV1<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
} // namespace ck_tile
|
||||
1018
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
Normal file
1018
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
Normal file
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,82 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
enum struct CastPolicy
|
||||
{
|
||||
BeforeLDSWrite,
|
||||
AfterLDSRead,
|
||||
};
|
||||
|
||||
enum struct GemmPipelineScheduler
|
||||
{
|
||||
Default,
|
||||
Intrawave,
|
||||
Interwave,
|
||||
};
|
||||
|
||||
enum struct TailNumber
|
||||
{
|
||||
// Single / Double buffer pipeline
|
||||
Odd,
|
||||
Even,
|
||||
|
||||
// Long prefetch pipeline, up to 8
|
||||
One,
|
||||
Two,
|
||||
Three,
|
||||
Four,
|
||||
Five,
|
||||
Six,
|
||||
Seven,
|
||||
|
||||
// Unroll stages > Prefetch stages, number of loop is multiple of unroll stages
|
||||
Empty,
|
||||
// Unroll stages <= Prefetch stages, number of loop is multiple of unroll stages add
|
||||
// prefetchstages
|
||||
Full,
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
inline std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os,
|
||||
const ck_tile::GemmPipelineScheduler& s)
|
||||
{
|
||||
switch(s)
|
||||
{
|
||||
case ck_tile::GemmPipelineScheduler::Default: os << "Default"; break;
|
||||
case ck_tile::GemmPipelineScheduler::Intrawave: os << "Intrawave"; break;
|
||||
case ck_tile::GemmPipelineScheduler::Interwave: os << "Interwave"; break;
|
||||
default: os << "";
|
||||
}
|
||||
return os;
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<([[clang::lifetimebound]] std::ostream& os,
|
||||
const ck_tile::TailNumber& s)
|
||||
{
|
||||
switch(s)
|
||||
{
|
||||
case ck_tile::TailNumber::Odd: os << "Odd"; break;
|
||||
case ck_tile::TailNumber::Even: os << "Even"; break;
|
||||
case ck_tile::TailNumber::One: os << "One"; break;
|
||||
case ck_tile::TailNumber::Two: os << "Two"; break;
|
||||
case ck_tile::TailNumber::Three: os << "Three"; break;
|
||||
case ck_tile::TailNumber::Four: os << "Four"; break;
|
||||
case ck_tile::TailNumber::Five: os << "Five"; break;
|
||||
case ck_tile::TailNumber::Six: os << "Six"; break;
|
||||
case ck_tile::TailNumber::Seven: os << "Seven"; break;
|
||||
case ck_tile::TailNumber::Empty: os << "Empty"; break;
|
||||
case ck_tile::TailNumber::Full: os << "Full"; break;
|
||||
default: os << "";
|
||||
}
|
||||
return os;
|
||||
}
|
||||
@@ -0,0 +1,377 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "gemm_pipeline_agmem_bgmem_creg_v1.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
template <typename Problem, typename Policy = GemmPipelineAgBgCrCompAsyncDefaultPolicy>
|
||||
struct GemmPipelineAGmemBGmemCRegAsyncV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Problem>
|
||||
{
|
||||
using Base = BaseGemmPipelineAGmemBGmemCRegV1<Problem>;
|
||||
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
|
||||
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
using I2 = number<2>;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr bool Async = true;
|
||||
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
|
||||
}
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeB()
|
||||
{
|
||||
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
|
||||
}
|
||||
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
|
||||
|
||||
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
|
||||
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool Preshuffle = Problem::Preshuffle;
|
||||
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
static constexpr index_t kLdsAlignmentInBytes = 16;
|
||||
|
||||
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
|
||||
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
|
||||
{
|
||||
// clang-format off
|
||||
return "BASIC_ASYNC_V1";
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "pipeline_AGmemBGmemCRegAsyncV1",
|
||||
concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize),
|
||||
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
|
||||
concat('x', kPadM, kPadN, kPadK));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
// For the basic gemm pipelien DoubleSmemBuffer set to be false naturally.
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <GemmPipelineScheduler Scheduler>
|
||||
struct PipelineImpl : public PipelineImplBase
|
||||
{
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
|
||||
{
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
template <bool HasHotLoop,
|
||||
typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
// TODO support multi-ABD
|
||||
static_assert(1 == std::tuple_size_v<AsDramBlockWindowTmp>);
|
||||
static_assert(1 == std::tuple_size_v<BsDramBlockWindowTmp>);
|
||||
using ADramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
|
||||
using BDramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
|
||||
// TODO currently fused elementwise are not supported
|
||||
ignore = a_element_func;
|
||||
ignore = b_element_func;
|
||||
static_assert(std::is_same_v<AElementFunction, element_wise::PassThrough>);
|
||||
static_assert(std::is_same_v<BElementFunction, element_wise::PassThrough>);
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"Data Type conflict on A and B matrix input data type.");
|
||||
|
||||
constexpr bool is_a_col_major =
|
||||
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
static_assert(is_a_col_major
|
||||
? (kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"A block window has incorrect lengths for defined ALayout!");
|
||||
static_assert(is_b_row_major
|
||||
? (kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
|
||||
////////////// global window & register /////////////////
|
||||
// A DRAM tile window(s) for load
|
||||
auto a_tile_windows =
|
||||
make_tile_window(a_dram_block_window_tmp[I0{}].get_bottom_tensor_view(),
|
||||
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
|
||||
a_dram_block_window_tmp[I0{}].get_window_origin(),
|
||||
Policy::template MakeADramTileDistribution<Problem>());
|
||||
// B DRAM window(s) for load
|
||||
auto b_tile_windows =
|
||||
make_tile_window(b_dram_block_window_tmp[I0{}].get_bottom_tensor_view(),
|
||||
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
|
||||
b_dram_block_window_tmp[I0{}].get_window_origin(),
|
||||
Policy::template MakeBDramTileDistribution<Problem>());
|
||||
|
||||
// this pipeline has a pair of LDS buffers per logical tile
|
||||
auto&& [a_lds_block, b_lds_block] = Base::GetABLdsTensorViews(p_smem);
|
||||
|
||||
// set up LDS tile shapes
|
||||
constexpr auto a_lds_shape = []() {
|
||||
if constexpr(is_a_load_tr_v)
|
||||
return make_tuple(number<kKPerBlock>{}, number<kMPerBlock>{});
|
||||
else
|
||||
return make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{});
|
||||
}();
|
||||
|
||||
constexpr auto b_lds_shape = []() {
|
||||
if constexpr(is_b_load_tr_v)
|
||||
return make_tuple(number<kKPerBlock>{}, number<kNPerBlock>{});
|
||||
else
|
||||
return make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{});
|
||||
}();
|
||||
|
||||
// LDS tile windows for storing, one per LDS buffer
|
||||
auto a_copy_lds_window = make_tile_window(a_lds_block, a_lds_shape, {0, 0});
|
||||
auto b_copy_lds_window = make_tile_window(b_lds_block, b_lds_shape, {0, 0});
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
|
||||
// Acc register tile
|
||||
auto c_block_tile = block_gemm.MakeCBlockTile();
|
||||
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step =
|
||||
is_a_col_major ? make_array(kKPerBlock, 0) : make_array(0, kKPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(kKPerBlock, 0) : make_array(0, kKPerBlock);
|
||||
|
||||
// tile distribution for the register tiles
|
||||
constexpr auto ALdsTileDistr =
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
|
||||
constexpr auto BLdsTileDistr =
|
||||
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
|
||||
|
||||
using ALdsTile = decltype(make_static_distributed_tensor<ADataType>(ALdsTileDistr));
|
||||
using BLdsTile = decltype(make_static_distributed_tensor<BDataType>(BLdsTileDistr));
|
||||
|
||||
// register tiles; double buffering -> a register tile corresponds to a LDS tile window
|
||||
ALdsTile a_block_tile;
|
||||
BLdsTile b_block_tile;
|
||||
|
||||
constexpr auto a_lds_input_tile_distr = [ALdsTileDistr]() {
|
||||
if constexpr(is_a_load_tr_v)
|
||||
return make_static_tile_distribution(
|
||||
typename InputTileDistributionTraits<
|
||||
typename decltype(ALdsTileDistr)::DstrEncode,
|
||||
typename Problem::ADataType>::TransposedDstrEncode{});
|
||||
else
|
||||
return ALdsTileDistr;
|
||||
}();
|
||||
constexpr auto b_lds_input_tile_distr = [BLdsTileDistr]() {
|
||||
if constexpr(is_b_load_tr_v)
|
||||
return make_static_tile_distribution(
|
||||
typename InputTileDistributionTraits<
|
||||
typename decltype(BLdsTileDistr)::DstrEncode,
|
||||
typename Problem::BDataType>::TransposedDstrEncode{});
|
||||
else
|
||||
return BLdsTileDistr;
|
||||
}();
|
||||
|
||||
// LDS tile windows for reading;
|
||||
// they share the data pointer with the LDS windows for storing
|
||||
// but also associate with a distribution to produce a register tile when reading
|
||||
auto a_lds_ld_window =
|
||||
make_tile_window(a_lds_block, a_lds_shape, {0, 0}, a_lds_input_tile_distr);
|
||||
auto b_lds_ld_window =
|
||||
make_tile_window(b_lds_block, b_lds_shape, {0, 0}, b_lds_input_tile_distr);
|
||||
|
||||
static_assert((!(is_tile_window_linear_v<decltype(a_lds_ld_window)>) &&
|
||||
!(is_tile_window_linear_v<decltype(b_lds_ld_window)>)),
|
||||
"LDS windows must not be linear");
|
||||
|
||||
// Global Prefetch
|
||||
Base::GlobalPrefetchAsync(a_copy_lds_window, a_tile_windows, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetchAsync(b_copy_lds_window, b_tile_windows, b_dram_tile_window_step);
|
||||
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
block_sync_lds_direct_load();
|
||||
|
||||
if constexpr(HasHotLoop)
|
||||
{
|
||||
index_t iCounter = num_loop - 1;
|
||||
while(iCounter > 0)
|
||||
{
|
||||
Base::LocalPrefetch(a_block_tile, a_lds_ld_window, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile, b_lds_ld_window, is_b_load_tr_v);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
Base::GlobalPrefetchAsync(
|
||||
a_copy_lds_window, a_tile_windows, a_dram_tile_window_step);
|
||||
Base::GlobalPrefetchAsync(
|
||||
b_copy_lds_window, b_tile_windows, b_dram_tile_window_step);
|
||||
|
||||
// GEMM i
|
||||
block_gemm(c_block_tile, a_block_tile, b_block_tile);
|
||||
|
||||
block_sync_lds_direct_load();
|
||||
|
||||
iCounter--;
|
||||
}
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
Base::LocalPrefetch(a_block_tile, a_lds_ld_window, is_a_load_tr_v);
|
||||
Base::LocalPrefetch(b_block_tile, b_lds_ld_window, is_b_load_tr_v);
|
||||
// GEMM num_loop - 1
|
||||
block_gemm(c_block_tile, a_block_tile, b_block_tile);
|
||||
}
|
||||
|
||||
return c_block_tile;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto RunPipeline = [&](auto hot_loop_) {
|
||||
constexpr bool hot_loop = hot_loop_.value;
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop>(
|
||||
a_dram_block_window_tmp,
|
||||
element_wise::PassThrough{},
|
||||
b_dram_block_window_tmp,
|
||||
element_wise::PassThrough{},
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto RunPipeline = [&](auto hot_loop_) {
|
||||
constexpr bool hot_loop = hot_loop_.value;
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop>(a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,641 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem>
|
||||
struct BaseGemmPipelineAGmemBGmemCRegV1
|
||||
{
|
||||
static constexpr index_t PrefetchStages = 1;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
static constexpr bool UsePersistentKernel = false;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t)
|
||||
{
|
||||
return TailNumber::Empty;
|
||||
}
|
||||
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool has_hot_loop)
|
||||
{
|
||||
// Use amd_wave_read_first_lane to avoid higher resource usage.
|
||||
// It forces to store these values in SGPR.
|
||||
// Compiler cannot deduce if one path is used for all threads
|
||||
const bool has_hot_loop_first_lane = amd_wave_read_first_lane(has_hot_loop);
|
||||
|
||||
if(has_hot_loop_first_lane)
|
||||
{
|
||||
return run_func(ck_tile::bool_constant<true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return run_func(ck_tile::bool_constant<false>{});
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
template <typename Problem, typename Policy = UniversalGemmPipelineAgBgCrPolicy>
|
||||
struct GemmPipelineAGmemBGmemCRegV1 : public BaseGemmPipelineAGmemBGmemCRegV1<Problem>
|
||||
{
|
||||
using Base = BaseGemmPipelineAGmemBGmemCRegV1<Problem>;
|
||||
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
|
||||
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
|
||||
using I0 = number<0>;
|
||||
using I1 = number<1>;
|
||||
using I2 = number<2>;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr bool Async = false;
|
||||
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
return Policy::template GetVectorSizeA<Problem, IsWave32Host>();
|
||||
}
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeB()
|
||||
{
|
||||
return Policy::template GetVectorSizeB<Problem, IsWave32Host>();
|
||||
}
|
||||
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
|
||||
|
||||
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
|
||||
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool Preshuffle = Problem::Preshuffle;
|
||||
|
||||
static constexpr auto Scheduler = Problem::Scheduler;
|
||||
|
||||
static constexpr auto is_a_load_tr_v = bool_constant<PipelineImplBase::is_a_load_tr>{};
|
||||
static constexpr auto is_b_load_tr_v = bool_constant<PipelineImplBase::is_b_load_tr>{};
|
||||
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
static constexpr index_t kLdsAlignmentInBytes = 16;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
|
||||
{
|
||||
// clang-format off
|
||||
return "BASIC_V1";
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "pipeline_AGmemBGmemCRegV1",
|
||||
concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize),
|
||||
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
|
||||
concat('x', kPadM, kPadN, kPadK));
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
// For the basic gemm pipelien DoubleSmemBuffer set to be false naturally.
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
template <GemmPipelineScheduler Scheduler>
|
||||
struct PipelineImpl : public PipelineImplBase
|
||||
{
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineImpl<GemmPipelineScheduler::Intrawave> : public PipelineImplBase
|
||||
{
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
template <bool HasHotLoop,
|
||||
typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
using ADramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
|
||||
using BDramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
constexpr bool is_a_col_major =
|
||||
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
static_assert(is_a_col_major
|
||||
? (kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"A block window has incorrect lengths for defined ALayout!");
|
||||
static_assert(is_b_row_major
|
||||
? (kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
// A tile in LDS
|
||||
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
|
||||
|
||||
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
|
||||
|
||||
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
|
||||
constexpr index_t a_lds_block_space_size_aligned =
|
||||
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(),
|
||||
kLdsAlignmentInBytes) *
|
||||
kLdsAlignmentInBytes;
|
||||
|
||||
// B tile in LDS
|
||||
BDataType* p_b_lds = static_cast<BDataType*>(
|
||||
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
|
||||
|
||||
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
|
||||
|
||||
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
|
||||
|
||||
// Tile distribution for load from lds
|
||||
constexpr auto a_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
|
||||
constexpr auto b_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
|
||||
|
||||
// A DRAM tile window for load
|
||||
// A LDS tile window for store
|
||||
// A LDS tile for block GEMM
|
||||
auto&& [as_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] =
|
||||
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
|
||||
|
||||
// B DRAM tile window for load
|
||||
// B LDS tile window for store
|
||||
// B LDS tile for block GEMM
|
||||
auto&& [bs_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
|
||||
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
|
||||
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
|
||||
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step =
|
||||
is_a_col_major ? make_array(kKPerBlock, 0) : make_array(0, kKPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step =
|
||||
is_b_row_major ? make_array(kKPerBlock, 0) : make_array(0, kKPerBlock);
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
|
||||
// Acc register tile
|
||||
auto c_block_tile = block_gemm.MakeCBlockTile();
|
||||
|
||||
// prefetch
|
||||
// global read 0
|
||||
// Load tile — during value loading, an elementwise function is executed for each A0,
|
||||
// A1, … AN. The values A0, A1, … AN are read by the same thread.
|
||||
auto elementwise_As_res =
|
||||
load_tile_with_elementwise(as_copy_dram_window, a_element_func);
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each B0,
|
||||
// B1, … BN. The values B0, B1, … BN are read by the same thread.
|
||||
auto elementwise_Bs_res =
|
||||
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
|
||||
{
|
||||
// move to 1
|
||||
// Move each A — the enhanced function move_tile_window is executed, which takes a
|
||||
// tuple as input.
|
||||
move_tile_window(as_copy_dram_window, a_dram_tile_window_step);
|
||||
// Move each B — the enhanced function move_tile_window is executed, which takes a
|
||||
// tuple as input.
|
||||
move_tile_window(bs_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// LDS write 0
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(HasHotLoop)
|
||||
{
|
||||
index_t iCounter = num_loop - 1;
|
||||
while(iCounter > 0)
|
||||
{
|
||||
// global read i + 1
|
||||
elementwise_As_res =
|
||||
load_tile_with_elementwise(as_copy_dram_window, a_element_func);
|
||||
elementwise_Bs_res =
|
||||
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
block_sync_lds();
|
||||
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
|
||||
// GEMM i
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// move to i + 2
|
||||
move_tile_window(as_copy_dram_window, a_dram_tile_window_step);
|
||||
move_tile_window(bs_copy_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
|
||||
iCounter--;
|
||||
}
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(
|
||||
a_lds_gemm_window, b_lds_gemm_window, is_a_load_tr_v, is_b_load_tr_v);
|
||||
// GEMM num_loop - 1
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
|
||||
return c_block_tile;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct PipelineImpl<GemmPipelineScheduler::Interwave> : public PipelineImplBase
|
||||
{
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
template <bool HasHotLoop,
|
||||
typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
using ADramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
|
||||
using BDramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
constexpr bool is_a_col_major =
|
||||
std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
|
||||
constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
|
||||
|
||||
static_assert(is_a_col_major
|
||||
? (kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"A block window has incorrect lengths for defined ALayout!");
|
||||
static_assert(is_b_row_major
|
||||
? (kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
|
||||
: (kNPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
|
||||
kKPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
|
||||
"B block window has incorrect lengths for defined BLayout!");
|
||||
// A tile in LDS
|
||||
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
|
||||
|
||||
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
|
||||
|
||||
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
|
||||
constexpr index_t a_lds_block_space_size_aligned =
|
||||
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(),
|
||||
kLdsAlignmentInBytes) *
|
||||
kLdsAlignmentInBytes;
|
||||
|
||||
// B tile in LDS
|
||||
BDataType* p_b_lds = static_cast<BDataType*>(
|
||||
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
|
||||
|
||||
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
|
||||
|
||||
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
|
||||
|
||||
// // Tile distribution for load from lds
|
||||
constexpr auto a_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
|
||||
constexpr auto b_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
|
||||
|
||||
// A DRAM tile window for load
|
||||
// A LDS tile window for store
|
||||
// A LDS tile for block GEMM
|
||||
auto&& [as_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] =
|
||||
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
|
||||
|
||||
// B DRAM tile window for load
|
||||
// B LDS tile window for store
|
||||
// B LDS tile for block GEMM
|
||||
auto&& [bs_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
|
||||
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
|
||||
// Acc register tile
|
||||
auto c_block_tile = block_gemm.MakeCBlockTile();
|
||||
|
||||
// prefetch
|
||||
// global read 0
|
||||
// Load tile — during value loading, an elementwise function is executed for each A0,
|
||||
// A1, … AN. The values A0, A1, … AN are read by the same thread.
|
||||
auto elementwise_As_res =
|
||||
load_tile_with_elementwise(as_copy_dram_window, a_element_func);
|
||||
|
||||
// Load tile — during value loading, an elementwise function is executed for each B0,
|
||||
// B1, … BN. The values B0, B1, … BN are read by the same thread.
|
||||
auto elementwise_Bs_res =
|
||||
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
|
||||
{
|
||||
// move to 1
|
||||
// Move each A — the enhanced function move_tile_window is executed, which takes a
|
||||
// tuple as input.
|
||||
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
|
||||
// Move each B — the enhanced function move_tile_window is executed, which takes a
|
||||
// tuple as input.
|
||||
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// LDS write 0
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr(HasHotLoop)
|
||||
{
|
||||
index_t iCounter = num_loop - 1;
|
||||
while(iCounter > 0)
|
||||
{
|
||||
// global read i + 1
|
||||
elementwise_As_res =
|
||||
load_tile_with_elementwise(as_copy_dram_window, a_element_func);
|
||||
block_sync_lds();
|
||||
elementwise_Bs_res =
|
||||
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
|
||||
// GEMM i
|
||||
block_gemm(c_block_tile,
|
||||
a_lds_gemm_window,
|
||||
b_lds_gemm_window,
|
||||
is_a_load_tr_v,
|
||||
is_b_load_tr_v);
|
||||
|
||||
// move to i + 2
|
||||
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// LDS write i + 1
|
||||
if constexpr(is_a_col_major && !is_a_load_tr_v())
|
||||
{
|
||||
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
|
||||
Policy::template MakeShuffledARegTileDistribution<Problem>());
|
||||
transpose_tile2d(a_shuffle_tmp, elementwise_As_res);
|
||||
Base::LocalPrefill(a_copy_lds_window, a_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(a_copy_lds_window, elementwise_As_res);
|
||||
}
|
||||
if constexpr(is_b_row_major && !is_b_load_tr_v())
|
||||
{
|
||||
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
|
||||
Policy::template MakeShuffledBRegTileDistribution<Problem>());
|
||||
transpose_tile2d(b_shuffle_tmp, elementwise_Bs_res);
|
||||
Base::LocalPrefill(b_copy_lds_window, b_shuffle_tmp);
|
||||
}
|
||||
else
|
||||
{
|
||||
Base::LocalPrefill(b_copy_lds_window, elementwise_Bs_res);
|
||||
}
|
||||
|
||||
iCounter--;
|
||||
}
|
||||
}
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
// GEMM num_loop - 1
|
||||
block_gemm(c_block_tile,
|
||||
a_lds_gemm_window,
|
||||
b_lds_gemm_window,
|
||||
is_a_load_tr_v,
|
||||
is_b_load_tr_v);
|
||||
}
|
||||
|
||||
return c_block_tile;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto RunPipeline = [&](auto hot_loop_) {
|
||||
constexpr bool hot_loop = hot_loop_.value;
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop>(
|
||||
a_dram_block_window_tmp,
|
||||
element_wise::PassThrough{},
|
||||
b_dram_block_window_tmp,
|
||||
element_wise::PassThrough{},
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
const bool has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto RunPipeline = [&](auto hot_loop_) {
|
||||
constexpr bool hot_loop = hot_loop_.value;
|
||||
return PipelineImpl<Scheduler>{}.template operator()<hot_loop>(a_dram_block_window_tmp,
|
||||
a_element_func,
|
||||
b_dram_block_window_tmp,
|
||||
b_element_func,
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,426 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Default policy for GemmPipelineAGmemBGmemCRegV1
|
||||
// Default policy class should not be templated, put template on member functions instead
|
||||
struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
{
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
|
||||
// 3d + padding
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}),
|
||||
make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kMPerBlock),
|
||||
make_merge_transform(make_tuple(kKPerBlock / 8, 8))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
|
||||
// 3d + padding
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<kKPerBlock / 8>{}, number<kNPerBlock>{}, number<8>{}),
|
||||
make_tuple(number<(kNPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
|
||||
number<8>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
b_lds_block_desc_0,
|
||||
make_tuple(make_pass_through_transform(kNPerBlock),
|
||||
make_merge_transform(make_tuple(kKPerBlock / 8, 8))),
|
||||
make_tuple(sequence<1>{}, sequence<0, 2>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return b_lds_block_desc;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
|
||||
{
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<typename Problem::ADataType>>::PackedSize;
|
||||
constexpr index_t smem_size_a =
|
||||
sizeof(typename Problem::ADataType) *
|
||||
MakeALdsBlockDescriptor<Problem>().get_element_space_size() / PackedSize;
|
||||
return smem_size_a;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
|
||||
{
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<typename Problem::BDataType>>::PackedSize;
|
||||
constexpr index_t smem_size_b =
|
||||
sizeof(typename Problem::BDataType) *
|
||||
MakeBLdsBlockDescriptor<Problem>().get_element_space_size() / PackedSize;
|
||||
return smem_size_b;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
|
||||
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
|
||||
constexpr index_t smem_size = smem_size_a + smem_size_b;
|
||||
|
||||
return smem_size;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPackA()
|
||||
{
|
||||
using A = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BlockGemm = remove_cvref_t<decltype(GetBlockGemm<Problem>())>;
|
||||
|
||||
constexpr index_t KPack = static_cast<index_t>(BlockGemm::Traits::KPack);
|
||||
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(A));
|
||||
|
||||
return (KPack < VecElems) ? KPack : VecElems;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPackB()
|
||||
{
|
||||
using B = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BlockGemm = remove_cvref_t<decltype(GetBlockGemm<Problem>())>;
|
||||
|
||||
constexpr index_t KPack = static_cast<index_t>(BlockGemm::Traits::KPack);
|
||||
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(B));
|
||||
|
||||
return (KPack < VecElems) ? KPack : VecElems;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
constexpr index_t M1 = Problem::VectorSizeA;
|
||||
constexpr index_t M0 = MPerBlock / M1;
|
||||
constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
|
||||
static_assert(total_pixels % M1 == 0);
|
||||
constexpr index_t K3 = total_pixels / M1;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
static_assert(KPack % K3 == 0);
|
||||
constexpr index_t K2 = KPack / K3;
|
||||
if constexpr(get_warp_size() >= (K2 * M0))
|
||||
{
|
||||
constexpr index_t K1 = get_warp_size() / (K2 * M0);
|
||||
constexpr index_t K0 = BlockSize / get_warp_size();
|
||||
static_assert(KPerBlock == K0 * K1 * K2 * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
|
||||
tuple<sequence<2>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0, 2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<3, 1>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = (K2 * M0) / get_warp_size();
|
||||
constexpr index_t K2_m = K2 / K1;
|
||||
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
|
||||
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
|
||||
tuple<sequence<2, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<3, 1>>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = 16 / sizeof(ADataType);
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
// coalesce reading for each blocks
|
||||
if constexpr(get_warp_size() % (M2 * K0) == 0)
|
||||
{
|
||||
constexpr index_t M1 = BlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
"Incorrect M0, M2, M1 configuration! "
|
||||
"M0, M1, M2 must cover whole MPerBlock!");
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t M0 = BlockSize / get_warp_size();
|
||||
constexpr index_t M1 = MPerBlock / (M2 * M0);
|
||||
static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
"Incorrect M0, M1, M2 configuration! "
|
||||
"M0, M1, M2 must cover whole MPerBlock!");
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
|
||||
{
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t N1 = Problem::VectorSizeB;
|
||||
constexpr index_t N0 = NPerBlock / N1;
|
||||
constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
|
||||
static_assert(total_pixels % N1 == 0);
|
||||
constexpr index_t K3 = total_pixels / N1;
|
||||
constexpr index_t KPack = GetSmemPackB<Problem>();
|
||||
static_assert(KPack % K3 == 0);
|
||||
constexpr index_t K2 = KPack / K3;
|
||||
if constexpr(get_warp_size() >= (K2 * N0))
|
||||
{
|
||||
constexpr index_t K1 = get_warp_size() / (K2 * N0);
|
||||
constexpr index_t K0 = BlockSize / get_warp_size();
|
||||
static_assert(KPerBlock == K0 * K1 * K2 * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
|
||||
tuple<sequence<2>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0, 2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<3, 1>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = (K2 * N0) / get_warp_size();
|
||||
constexpr index_t K2_m = K2 / K1;
|
||||
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
|
||||
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
|
||||
tuple<sequence<2, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<3, 1>>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
|
||||
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType);
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t N2 = get_warp_size() / K0;
|
||||
// coalesce reading for each blocks
|
||||
if constexpr(get_warp_size() % (N2 * K0) == 0)
|
||||
{
|
||||
constexpr index_t N1 = BlockSize / get_warp_size();
|
||||
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t N0 = NPerBlock / (N2 * N1);
|
||||
static_assert(N0 * N1 * N2 == NPerBlock,
|
||||
"Incorrect N0, N1, N2 configuration! "
|
||||
"N0, N1, N2 must cover whole NPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
// coalesce reading for each warps
|
||||
else
|
||||
{
|
||||
constexpr index_t N0 = BlockSize / get_warp_size();
|
||||
constexpr index_t N1 = NPerBlock / (N2 * N0);
|
||||
static_assert(N0 * N1 * N2 == NPerBlock,
|
||||
"Incorrect N0, N1, N2 configuration! "
|
||||
"N0, N1, N2 must cover whole NPerBlock!");
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDistribution()
|
||||
{
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
|
||||
constexpr index_t N0 = kNPerBlock / N1;
|
||||
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
|
||||
static_assert(total_pixels % N1 == 0);
|
||||
constexpr index_t K3 = total_pixels / N1;
|
||||
constexpr index_t kKPack = GetSmemPackB<Problem>();
|
||||
static_assert(kKPack % K3 == 0);
|
||||
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
|
||||
constexpr index_t warp_size = get_warp_size();
|
||||
if constexpr(warp_size % (K2 * N0) == 0)
|
||||
{
|
||||
constexpr index_t K1 = warp_size / (K2 * N0);
|
||||
constexpr index_t K0 = kBlockSize / warp_size;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
|
||||
tuple<sequence<2>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 3>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = (K2 * N0) / get_warp_size();
|
||||
constexpr index_t K2_m = K2 / K1;
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
|
||||
static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
|
||||
tuple<sequence<2, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 3>>{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDistribution()
|
||||
{
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t M0 = kMPerBlock / M1;
|
||||
constexpr index_t total_pixels = kMPerBlock * kKPerBlock / kBlockSize;
|
||||
static_assert(total_pixels % M1 == 0);
|
||||
constexpr index_t K3 = total_pixels / M1;
|
||||
constexpr index_t kKPack = GetSmemPackA<Problem>();
|
||||
static_assert(kKPack % K3 == 0);
|
||||
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
|
||||
constexpr index_t warp_size = get_warp_size();
|
||||
if constexpr(warp_size % (K2 * M0) == 0)
|
||||
{
|
||||
constexpr index_t K1 = warp_size / (K2 * M0);
|
||||
constexpr index_t K0 = kBlockSize / warp_size;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
|
||||
tuple<sequence<2>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 3>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = (K2 * M0) / get_warp_size();
|
||||
constexpr index_t K2_m = K2 / K1;
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
|
||||
static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
|
||||
tuple<sequence<2, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 3>>{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
using AccDataType = float;
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
using WarpGemm = WarpGemmDispatcher<typename Problem::ComputeDataType,
|
||||
typename Problem::ComputeDataType,
|
||||
AccDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC>;
|
||||
|
||||
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
|
||||
return BlockUniversalGemmAsBsCr<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,342 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem>
|
||||
struct BaseGemmPipelineAGmemBGmemCRegV2
|
||||
{
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t) { return true; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t)
|
||||
{
|
||||
return TailNumber::Empty;
|
||||
}
|
||||
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto TailHandler(const RunFunction& run_func, bool, TailNumber)
|
||||
{
|
||||
return run_func(bool_constant<true>{}, integral_constant<TailNumber, TailNumber::Empty>{});
|
||||
}
|
||||
};
|
||||
// A Tile Window: global memory
|
||||
// B Tile Window: global memory
|
||||
// C Distributed tensor: register
|
||||
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCRegV2DefaultPolicy>
|
||||
struct GemmPipelineAGmemBGmemCRegV2 : public BaseGemmPipelineAGmemBGmemCRegV2<Problem>
|
||||
{
|
||||
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
|
||||
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
|
||||
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
|
||||
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
|
||||
using BlockGemm = remove_cvref_t<decltype(Policy::template GetBlockGemm<Problem>())>;
|
||||
|
||||
static constexpr index_t APackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
static constexpr index_t BPackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr bool Async = false;
|
||||
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
return Problem::VectorSizeA;
|
||||
}
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeB()
|
||||
{
|
||||
return Problem::VectorSizeB;
|
||||
}
|
||||
static constexpr index_t GetVectorSizeC() { return Problem::VectorSizeC; }
|
||||
|
||||
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
|
||||
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr bool Preshuffle = Problem::Preshuffle;
|
||||
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
// For the basic gemm pipelien DoubleSmemBuffer set to be false naturally.
|
||||
static constexpr bool DoubleSmemBuffer = false;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
|
||||
{
|
||||
// clang-format off
|
||||
return "BASIC_V2";
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "pipeline_AGmemBGmemCRegV2",
|
||||
concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize));
|
||||
// clang-format on
|
||||
}
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
|
||||
{
|
||||
return integer_divide_ceil(sizeof(ADataType) *
|
||||
Policy::template MakeALdsBlockDescriptor<Problem>()
|
||||
.get_element_space_size() /
|
||||
APackedSize,
|
||||
16) *
|
||||
16 +
|
||||
sizeof(BDataType) *
|
||||
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size() /
|
||||
BPackedSize;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
}
|
||||
|
||||
struct PipelineImpl : public PipelineImplBase
|
||||
{
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_HOST_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const AElementFunction& a_element_func,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
|
||||
using ADramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, AsDramBlockWindowTmp>>;
|
||||
using BDramBlockWindowTmp =
|
||||
remove_cvref_t<std::tuple_element_t<number<0>{}, BsDramBlockWindowTmp>>;
|
||||
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>> &&
|
||||
std::is_same_v<BDataType,
|
||||
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kNPerBlock ==
|
||||
BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
|
||||
kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// A tile in LDS
|
||||
ADataType* p_a_lds = static_cast<ADataType*>(p_smem);
|
||||
|
||||
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
|
||||
|
||||
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
|
||||
constexpr index_t a_lds_block_space_size_aligned =
|
||||
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size() /
|
||||
APackedSize,
|
||||
16) *
|
||||
16;
|
||||
|
||||
// B tile in LDS
|
||||
BDataType* p_b_lds = static_cast<BDataType*>(
|
||||
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
|
||||
|
||||
constexpr auto b_lds_block_desc = Policy::template MakeBLdsBlockDescriptor<Problem>();
|
||||
|
||||
auto b_lds_block = make_tensor_view<address_space_enum::lds>(p_b_lds, b_lds_block_desc);
|
||||
|
||||
// Tile distribution for load from lds
|
||||
constexpr auto a_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeABlockDistributionEncode());
|
||||
constexpr auto b_lds_load_tile_distr =
|
||||
make_static_tile_distribution(BlockGemm::MakeBBlockDistributionEncode());
|
||||
|
||||
// A DRAM tile window for load
|
||||
// A LDS tile window for store
|
||||
// A LDS tile for block GEMM
|
||||
auto&& [as_copy_dram_window, a_copy_lds_window, a_lds_gemm_window] =
|
||||
Base::GetAWindows(a_dram_block_window_tmp, a_lds_block, a_lds_load_tile_distr);
|
||||
|
||||
// B DRAM tile window for load
|
||||
// B LDS tile window for store
|
||||
// B LDS tile for block GEMM
|
||||
auto&& [bs_copy_dram_window, b_copy_lds_window, b_lds_gemm_window] =
|
||||
Base::GetBWindows(b_dram_block_window_tmp, b_lds_block, b_lds_load_tile_distr);
|
||||
|
||||
// Block GEMM
|
||||
auto block_gemm = BlockGemm();
|
||||
|
||||
// Acc register tile
|
||||
auto c_block_tile = block_gemm.MakeCBlockTile();
|
||||
|
||||
// prefetch
|
||||
// global read 0
|
||||
// Load tile — during value loading, an elementwise function is executed for each A0,
|
||||
// A1, … AN. The values A0, A1, … AN are read by the same thread.
|
||||
auto elementwise_As_res =
|
||||
load_tile_with_elementwise(as_copy_dram_window, a_element_func);
|
||||
// Load tile — during value loading, an elementwise function is executed for each B0,
|
||||
// B1, … BN. The values B0, B1, … BN are read by the same thread.
|
||||
auto elementwise_Bs_res =
|
||||
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
|
||||
{
|
||||
// move to 1
|
||||
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
// LDS write 0
|
||||
store_tile(a_copy_lds_window, elementwise_As_res);
|
||||
// global read 1
|
||||
elementwise_As_res =
|
||||
load_tile_with_elementwise(as_copy_dram_window, a_element_func);
|
||||
|
||||
// LDS write 0
|
||||
store_tile(b_copy_lds_window, elementwise_Bs_res);
|
||||
// global read 1
|
||||
elementwise_Bs_res =
|
||||
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
}
|
||||
|
||||
index_t iCounter = num_loop - 2;
|
||||
|
||||
do
|
||||
{
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
// GEMM i
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// move to i + 2
|
||||
move_tile_window(as_copy_dram_window, {0, kKPerBlock});
|
||||
move_tile_window(bs_copy_dram_window, {0, kKPerBlock});
|
||||
|
||||
// LDS write i + 1
|
||||
store_tile(a_copy_lds_window, elementwise_As_res);
|
||||
// global read i + 2
|
||||
elementwise_As_res =
|
||||
load_tile_with_elementwise(as_copy_dram_window, a_element_func);
|
||||
|
||||
// LDS write i + 1
|
||||
store_tile(b_copy_lds_window, elementwise_Bs_res);
|
||||
// global read i + 2
|
||||
elementwise_Bs_res =
|
||||
load_tile_with_elementwise(bs_copy_dram_window, b_element_func);
|
||||
|
||||
iCounter--;
|
||||
|
||||
} while(iCounter > 0);
|
||||
|
||||
// tail
|
||||
{
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
// GEMM num_loop - 2
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// LDS write num_loop - 1
|
||||
store_tile(a_copy_lds_window, elementwise_As_res);
|
||||
|
||||
store_tile(b_copy_lds_window, elementwise_Bs_res);
|
||||
|
||||
block_sync_lds();
|
||||
block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
|
||||
// GEMM num_loop - 1
|
||||
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
|
||||
}
|
||||
|
||||
return c_block_tile;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename AsDramBlockWindowTmp,
|
||||
typename BsDramBlockWindowTmp,
|
||||
typename std::enable_if_t<is_detected<is_tuple, AsDramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BsDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const AsDramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BsDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return PipelineImpl{}.operator()(
|
||||
a_dram_block_window_tmp,
|
||||
[](auto& e, const ADataType & a) { e = a; },
|
||||
b_dram_block_window_tmp,
|
||||
[](auto& e, const BDataType & b) { e = b; },
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BDramBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BDramBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BDramBlockWindowTmp& b_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
return operator()(ck_tile::make_tuple(a_dram_block_window_tmp),
|
||||
ck_tile::make_tuple(b_dram_block_window_tmp),
|
||||
num_loop,
|
||||
p_smem);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,17 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// Default policy for GemmPipelineAGmemBGmemCRegV2
|
||||
// Default policy class should not be templated, put template on member functions instead
|
||||
// NOTE: policy should be binded to its corresponding operation. It's just a coincidence that
|
||||
// GemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as
|
||||
// GemmPipelineAGmemBGmemCRegV1DefaultPolicy
|
||||
using GemmPipelineAGmemBGmemCRegV2DefaultPolicy = GemmPipelineAGmemBGmemCRegV1DefaultPolicy;
|
||||
|
||||
} // namespace ck_tile
|
||||
462
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
Normal file
462
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
Normal file
@@ -0,0 +1,462 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename AsDataType_,
|
||||
typename BsDataType_,
|
||||
typename EDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
typename ComputeDataType_ = AsDataType_,
|
||||
typename AElementWise_ = ck_tile::element_wise::PassThrough,
|
||||
typename BElementWise_ = ck_tile::element_wise::PassThrough,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1>
|
||||
struct GemmPipelineProblemBase
|
||||
{
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
using AsDataType = remove_cvref_t<AsDataType_>;
|
||||
using BsDataType = remove_cvref_t<BsDataType_>;
|
||||
using CDataType = remove_cvref_t<EDataType_>; // actually AccDataType
|
||||
|
||||
static constexpr bool FixedVectorSize = FixedVectorSize_;
|
||||
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
using AElementWise = remove_cvref_t<AElementWise_>;
|
||||
using BElementWise = remove_cvref_t<BElementWise_>;
|
||||
|
||||
using AsLayout = remove_cvref_t<typename Traits::AsLayout>;
|
||||
using BsLayout = remove_cvref_t<typename Traits::BsLayout>;
|
||||
using CLayout = remove_cvref_t<typename Traits::CLayout>;
|
||||
|
||||
static constexpr bool ComputeDataTypeIsTuple = is_detected<is_tuple, ComputeDataType_>::value;
|
||||
static constexpr bool ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value;
|
||||
static constexpr bool BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value;
|
||||
|
||||
static constexpr bool ALayoutIsTuple = is_detected<is_tuple, AsLayout>::value;
|
||||
static constexpr bool BLayoutIsTuple = is_detected<is_tuple, BsLayout>::value;
|
||||
|
||||
using ComputeDataTypeTuple = std::conditional_t<ComputeDataTypeIsTuple,
|
||||
remove_cvref_t<ComputeDataType_>,
|
||||
remove_cvref_t<tuple<ComputeDataType_>>>;
|
||||
using AsLayoutTuple = std::
|
||||
conditional_t<ALayoutIsTuple, remove_cvref_t<AsLayout>, remove_cvref_t<tuple<AsLayout>>>;
|
||||
using BsLayoutTuple = std::
|
||||
conditional_t<BLayoutIsTuple, remove_cvref_t<BsLayout>, remove_cvref_t<tuple<BsLayout>>>;
|
||||
|
||||
using AsDataTypeTuple = std::conditional_t<ADataTypeIsTuple,
|
||||
remove_cvref_t<AsDataType>,
|
||||
remove_cvref_t<tuple<AsDataType>>>;
|
||||
|
||||
using BsDataTypeTuple = std::conditional_t<BDataTypeIsTuple,
|
||||
remove_cvref_t<BsDataType>,
|
||||
remove_cvref_t<tuple<BsDataType>>>;
|
||||
|
||||
using ComputeDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, ComputeDataTypeTuple>>;
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayoutTuple>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayoutTuple>>;
|
||||
|
||||
static constexpr bool TransposeC = Traits::TransposeC;
|
||||
static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
|
||||
static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
|
||||
|
||||
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
|
||||
|
||||
static constexpr bool kPadM = Traits::kPadM;
|
||||
static constexpr bool kPadN = Traits::kPadN;
|
||||
static constexpr bool kPadK = Traits::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
|
||||
static constexpr auto Scheduler = GemmPipelineScheduler::Intrawave;
|
||||
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
|
||||
|
||||
// In the base situation, the Preshuffle setting should be false.
|
||||
static constexpr bool Preshuffle = false;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "gemm_problem",
|
||||
concat('x', VectorLoadSize, kBlockSize),
|
||||
concat('x', kPadM, kPadN, kPadK),
|
||||
Scheduler);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
|
||||
{
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
constexpr index_t pixels_per_thread =
|
||||
BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
|
||||
return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(ADataType)
|
||||
? pixels_per_thread
|
||||
: PackedSize * VectorLoadSize / sizeof(ADataType);
|
||||
}
|
||||
else
|
||||
{
|
||||
return PackedSize * VectorLoadSize / sizeof(ADataType);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB()
|
||||
{
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t pixels_per_thread =
|
||||
BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
|
||||
return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(BDataType)
|
||||
? pixels_per_thread
|
||||
: PackedSize * VectorLoadSize / sizeof(BDataType);
|
||||
}
|
||||
else
|
||||
{
|
||||
return PackedSize * VectorLoadSize / sizeof(BDataType);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentC()
|
||||
{
|
||||
if constexpr(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N2 = std::min(BlockGemmShape::kN / N1, get_warp_size());
|
||||
constexpr index_t M0 = get_warp_size() / N2;
|
||||
constexpr index_t M1 = BlockGemmShape::kM / M0;
|
||||
|
||||
return std::min(M1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t M1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M2 = std::min(BlockGemmShape::kM / M1, get_warp_size());
|
||||
constexpr index_t N0 = get_warp_size() / M2;
|
||||
constexpr index_t N1 = BlockGemmShape::kN / N0;
|
||||
|
||||
return std::min(N1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr index_t VectorSizeA = []() {
|
||||
if constexpr(FixedVectorSize)
|
||||
{
|
||||
return VectorSizeA_;
|
||||
}
|
||||
else if constexpr(std::is_same_v<AsLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return kPadK ? 1 : GetAlignmentA();
|
||||
}
|
||||
else
|
||||
{
|
||||
return kPadM ? 1 : GetAlignmentA();
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr index_t VectorSizeB = []() {
|
||||
if constexpr(FixedVectorSize)
|
||||
{
|
||||
return VectorSizeB_;
|
||||
}
|
||||
else if constexpr(std::is_same_v<BsLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return kPadN ? 1 : GetAlignmentB();
|
||||
}
|
||||
else
|
||||
{
|
||||
return kPadK ? 1 : GetAlignmentB();
|
||||
}
|
||||
}();
|
||||
static constexpr index_t VectorSizeC = []() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return kPadN ? 1 : GetAlignmentC();
|
||||
}
|
||||
else
|
||||
{
|
||||
return kPadM ? 1 : GetAlignmentC();
|
||||
}
|
||||
}();
|
||||
};
|
||||
|
||||
template <typename AsDataType_,
|
||||
typename BsDataType_,
|
||||
typename EDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
typename AElementWise_ = ck_tile::element_wise::PassThrough,
|
||||
typename BElementWise_ = ck_tile::element_wise::PassThrough,
|
||||
typename ComputeDataType_ = AsDataType_,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1>
|
||||
using GemmPipelineProblem = GemmPipelineProblemBase<AsDataType_,
|
||||
BsDataType_,
|
||||
EDataType_,
|
||||
BlockGemmShape_,
|
||||
Traits_,
|
||||
ComputeDataType_,
|
||||
AElementWise_,
|
||||
BElementWise_,
|
||||
FixedVectorSize_,
|
||||
VectorSizeA_,
|
||||
VectorSizeB_>;
|
||||
|
||||
template <typename AsDataType_,
|
||||
typename BsDataType_,
|
||||
typename EDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
typename AElementWise_ = ck_tile::element_wise::PassThrough,
|
||||
typename BElementWise_ = ck_tile::element_wise::PassThrough,
|
||||
typename ComputeDataType_ = AsDataType_,
|
||||
bool FixedVectorSize_ = false,
|
||||
index_t VectorSizeA_ = 1,
|
||||
index_t VectorSizeB_ = 1>
|
||||
struct UniversalGemmPipelineProblem
|
||||
{
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
using AsDataType = remove_cvref_t<AsDataType_>;
|
||||
using BsDataType = remove_cvref_t<BsDataType_>;
|
||||
using CDataType = remove_cvref_t<EDataType_>; // actually AccDataType
|
||||
using AElementWise = remove_cvref_t<AElementWise_>;
|
||||
using BElementWise = remove_cvref_t<BElementWise_>;
|
||||
|
||||
static constexpr bool FixedVectorSize = FixedVectorSize_;
|
||||
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
using AsLayout = remove_cvref_t<typename Traits::AsLayout>;
|
||||
using BsLayout = remove_cvref_t<typename Traits::BsLayout>;
|
||||
using CLayout = remove_cvref_t<typename Traits::CLayout>;
|
||||
|
||||
static constexpr bool ComputeDataTypeIsTuple = is_detected<is_tuple, ComputeDataType_>::value;
|
||||
static constexpr bool ADataTypeIsTuple = is_detected<is_tuple, AsDataType>::value;
|
||||
static constexpr bool BDataTypeIsTuple = is_detected<is_tuple, BsDataType>::value;
|
||||
|
||||
static constexpr bool ALayoutIsTuple = is_detected<is_tuple, AsLayout>::value;
|
||||
static constexpr bool BLayoutIsTuple = is_detected<is_tuple, BsLayout>::value;
|
||||
|
||||
using ComputeDataTypeTuple = std::conditional_t<ComputeDataTypeIsTuple,
|
||||
remove_cvref_t<ComputeDataType_>,
|
||||
remove_cvref_t<tuple<ComputeDataType_>>>;
|
||||
using AsLayoutTuple = std::
|
||||
conditional_t<ALayoutIsTuple, remove_cvref_t<AsLayout>, remove_cvref_t<tuple<AsLayout>>>;
|
||||
using BsLayoutTuple = std::
|
||||
conditional_t<BLayoutIsTuple, remove_cvref_t<BsLayout>, remove_cvref_t<tuple<BsLayout>>>;
|
||||
|
||||
using AsDataTypeTuple = std::conditional_t<ADataTypeIsTuple,
|
||||
remove_cvref_t<AsDataType>,
|
||||
remove_cvref_t<tuple<AsDataType>>>;
|
||||
|
||||
using BsDataTypeTuple = std::conditional_t<BDataTypeIsTuple,
|
||||
remove_cvref_t<BsDataType>,
|
||||
remove_cvref_t<tuple<BsDataType>>>;
|
||||
|
||||
using ComputeDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, ComputeDataTypeTuple>>;
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataTypeTuple>>;
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayoutTuple>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<number<0>{}, BsDataTypeTuple>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayoutTuple>>;
|
||||
|
||||
static constexpr bool TransposeC = Traits::TransposeC;
|
||||
static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
|
||||
static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
|
||||
|
||||
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
|
||||
|
||||
static constexpr bool kPadM = Traits::kPadM;
|
||||
static constexpr bool kPadN = Traits::kPadN;
|
||||
static constexpr bool kPadK = Traits::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
|
||||
static constexpr auto Scheduler = Scheduler_;
|
||||
static constexpr bool Preshuffle = Traits::Preshuffle;
|
||||
|
||||
static constexpr index_t VectorSizeA = VectorSizeA_;
|
||||
static constexpr index_t VectorSizeB = VectorSizeB_;
|
||||
|
||||
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "gemm_problem",
|
||||
concat('x', kBlockSize),
|
||||
concat('x', kPadM, kPadN, kPadK),
|
||||
Scheduler,
|
||||
"NumWaveGroups",
|
||||
NumWaveGroups,
|
||||
"DoubleSmemBuffer",
|
||||
DoubleSmemBuffer
|
||||
);
|
||||
// clang-format on
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename CDataType_,
|
||||
typename BlockGemmShape_,
|
||||
typename Traits_,
|
||||
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
|
||||
bool HasHotLoop_ = true,
|
||||
TailNumber TailNum_ = TailNumber::Full,
|
||||
amd_buffer_coherence_enum BMemNTType_ = amd_buffer_coherence_enum::coherence_default,
|
||||
bool BPreShufflePermute_ = false,
|
||||
typename ComputeDataType_ = ADataType_>
|
||||
struct FlatmmPipelineProblem
|
||||
{
|
||||
using Traits = remove_cvref_t<Traits_>;
|
||||
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CDataType = remove_cvref_t<CDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
|
||||
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
|
||||
|
||||
using ALayout = remove_cvref_t<typename Traits::AsLayout>;
|
||||
using BLayout = remove_cvref_t<typename Traits::BsLayout>;
|
||||
using CLayout = remove_cvref_t<typename Traits::CLayout>;
|
||||
|
||||
static constexpr bool TransposeC = Traits::TransposeC;
|
||||
static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
|
||||
static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
|
||||
|
||||
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
|
||||
|
||||
static constexpr bool kPadM = Traits::kPadM;
|
||||
static constexpr bool kPadN = Traits::kPadN;
|
||||
static constexpr bool kPadK = Traits::kPadK;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
|
||||
|
||||
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
|
||||
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
|
||||
|
||||
static constexpr auto HasHotLoop = HasHotLoop_;
|
||||
static constexpr auto TailNum = TailNum_;
|
||||
|
||||
static constexpr auto BMemNTType = BMemNTType_;
|
||||
static constexpr bool BPreShufflePermute = BPreShufflePermute_;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "gemm_problem",
|
||||
concat('x', VectorLoadSize, kBlockSize),
|
||||
concat('x', kPadM, kPadN, kPadK),
|
||||
Scheduler);
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
|
||||
{
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
constexpr index_t pixels_per_thread =
|
||||
BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
|
||||
return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(ADataType)
|
||||
? pixels_per_thread
|
||||
: PackedSize * VectorLoadSize / sizeof(ADataType);
|
||||
}
|
||||
else
|
||||
{
|
||||
return VectorLoadSize / sizeof(ADataType);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB()
|
||||
{
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t pixels_per_thread =
|
||||
BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
|
||||
return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(BDataType)
|
||||
? pixels_per_thread
|
||||
: PackedSize * VectorLoadSize / sizeof(BDataType);
|
||||
}
|
||||
else
|
||||
{
|
||||
return PackedSize * VectorLoadSize / sizeof(BDataType);
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentC()
|
||||
{
|
||||
if constexpr(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
constexpr index_t N1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t N2 = std::min(BlockGemmShape::kN / N1, get_warp_size());
|
||||
constexpr index_t M0 = get_warp_size() / N2;
|
||||
constexpr index_t M1 = BlockGemmShape::kM / M0;
|
||||
|
||||
return std::min(M1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t M1 = kBlockSize / get_warp_size();
|
||||
constexpr index_t M2 = std::min(BlockGemmShape::kM / M1, get_warp_size());
|
||||
constexpr index_t N0 = get_warp_size() / M2;
|
||||
constexpr index_t N1 = BlockGemmShape::kN / N0;
|
||||
|
||||
return std::min(N1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
|
||||
}
|
||||
}
|
||||
|
||||
static constexpr index_t VectorSizeA = []() {
|
||||
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return kPadK ? 1 : GetAlignmentA();
|
||||
}
|
||||
else
|
||||
{
|
||||
return kPadM ? 1 : GetAlignmentA();
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr index_t VectorSizeB = []() {
|
||||
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
return kPadN ? 1 : GetAlignmentB();
|
||||
}
|
||||
else
|
||||
{
|
||||
return kPadK ? 1 : GetAlignmentB();
|
||||
}
|
||||
}();
|
||||
static constexpr index_t VectorSizeC = []() {
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return kPadN ? 1 : GetAlignmentC();
|
||||
}
|
||||
else
|
||||
{
|
||||
return kPadM ? 1 : GetAlignmentC();
|
||||
}
|
||||
}();
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
22
include/ck_tile/ops/gemm/pipeline/gemm_pipelines.hpp
Normal file
22
include/ck_tile/ops/gemm/pipeline/gemm_pipelines.hpp
Normal file
@@ -0,0 +1,22 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
enum struct GemmPipeline
|
||||
{
|
||||
COMPUTE_ASYNC,
|
||||
COMPUTE_V3,
|
||||
COMPUTE_V4,
|
||||
COMPUTE_V5,
|
||||
COMPUTE_V6,
|
||||
MEMORY,
|
||||
BASIC_V1,
|
||||
BASIC_V2,
|
||||
PRESHUFFLE_V2,
|
||||
BASIC_ASYNC_V1
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,942 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename T, typename = void>
|
||||
struct has_a_tile_access_pattern : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct has_a_tile_access_pattern<T, std::void_t<decltype(T::ATileAccessPattern)>> : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T, typename = void>
|
||||
struct has_b_tile_access_pattern : std::false_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct has_b_tile_access_pattern<T, std::void_t<decltype(T::BTileAccessPattern)>> : std::true_type
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Derived>
|
||||
struct UniversalGemmBasePolicy
|
||||
{
|
||||
#if defined(__gfx950__)
|
||||
// The combination of pk_int4_t and transposed loading causes numerical errors.
|
||||
// Therefore do not use transposed loading in this case.
|
||||
// Also, transpose load (ds_read_tr) requires specific tile distribution patterns
|
||||
// that only work for certain K warp tile sizes based on data type size:
|
||||
// - For 1-byte types (fp8/bf8): K warp tile <= 64
|
||||
// - For 2-byte types (fp16/bf16): K warp tile <= 32
|
||||
template <typename Problem>
|
||||
static constexpr bool is_a_load_tr = []() {
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
|
||||
// Max K warp tile for transpose load based on data type size
|
||||
constexpr index_t kMaxKWarpTile = (sizeof(ADataType) == 1) ? 64 : 32;
|
||||
if constexpr(std::is_same_v<ADataType, float>)
|
||||
return false;
|
||||
else if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
return false;
|
||||
else if constexpr(kKWarpTile > kMaxKWarpTile)
|
||||
return false;
|
||||
else
|
||||
return std::is_same_v<remove_cvref_t<typename Problem::ALayout>,
|
||||
tensor_layout::gemm::ColumnMajor>;
|
||||
}();
|
||||
|
||||
template <typename Problem>
|
||||
static constexpr bool is_b_load_tr = []() {
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
constexpr index_t kKWarpTile = WarpTile::at(number<2>{});
|
||||
// Max K warp tile for transpose load based on data type size
|
||||
constexpr index_t kMaxKWarpTile = (sizeof(BDataType) == 1) ? 64 : 32;
|
||||
if constexpr(std::is_same_v<BDataType, float>)
|
||||
return false;
|
||||
else if constexpr(std::is_same_v<BDataType, pk_int4_t>)
|
||||
return false;
|
||||
else if constexpr(kKWarpTile > kMaxKWarpTile)
|
||||
return false;
|
||||
else
|
||||
return std::is_same_v<remove_cvref_t<typename Problem::BLayout>,
|
||||
tensor_layout::gemm::RowMajor>;
|
||||
}();
|
||||
#else
|
||||
template <typename Problem>
|
||||
static constexpr bool is_a_load_tr = false;
|
||||
template <typename Problem>
|
||||
static constexpr bool is_b_load_tr = false;
|
||||
#endif
|
||||
|
||||
template <typename T>
|
||||
using has_bcastpolicy_type = decltype(T::BCastPolicy);
|
||||
|
||||
template <typename Problem>
|
||||
static constexpr bool IsBCastPolicyBeforeLDSWrite_v = [] {
|
||||
if constexpr(is_detected<has_bcastpolicy_type, Problem>{})
|
||||
{
|
||||
return Problem::BCastPolicy == CastPolicy::BeforeLDSWrite;
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}();
|
||||
|
||||
static constexpr auto I0 = number<0>{};
|
||||
static constexpr auto I1 = number<1>{};
|
||||
static constexpr auto I2 = number<2>{};
|
||||
|
||||
// Default tile access patterns
|
||||
static constexpr auto DefaultATileAccessPattern = tile_distribution_pattern::thread_raked;
|
||||
static constexpr auto DefaultBTileAccessPattern = tile_distribution_pattern::thread_raked;
|
||||
|
||||
static constexpr auto getATileAccessPattern()
|
||||
{
|
||||
if constexpr(has_a_tile_access_pattern<Derived>::value)
|
||||
return Derived::ATileAccessPattern;
|
||||
else
|
||||
return DefaultATileAccessPattern;
|
||||
}
|
||||
|
||||
static constexpr auto getBTileAccessPattern()
|
||||
{
|
||||
if constexpr(has_b_tile_access_pattern<Derived>::value)
|
||||
return Derived::BTileAccessPattern;
|
||||
else
|
||||
return DefaultBTileAccessPattern;
|
||||
}
|
||||
|
||||
template <typename Problem,
|
||||
typename OverrideADataType = remove_cvref_t<typename Problem::ADataType>>
|
||||
CK_TILE_DEVICE static constexpr auto MakeALdsBlockDescriptor()
|
||||
{
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using ADataType = OverrideADataType;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
if constexpr(is_a_load_tr<Problem>)
|
||||
{
|
||||
// TODO: better lds descriptor for performance
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( //
|
||||
make_tuple(number<KPerBlock>{}, number<MPerBlock>{}),
|
||||
make_tuple(number<MPerBlock>{}, number<1>{}),
|
||||
number<MPerBlock>{},
|
||||
number<1>{});
|
||||
return a_lds_block_desc_0;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Only use this ColumnMajor layout for Wave64 mode (gfx9)
|
||||
constexpr auto Wave64 = get_warp_size() == 64;
|
||||
if constexpr(Wave64 &&
|
||||
std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
// kfold and mpair dimension is not always required.
|
||||
// more dimension in merge_transform increase the difficulty of generating immarg
|
||||
// offset for compiler.
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
|
||||
using TileEncodingPattern =
|
||||
tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
KPerBlock,
|
||||
MPerBlock,
|
||||
VecLoadSize,
|
||||
getATileAccessPattern()>;
|
||||
// AK1: the shuffled tile dstr has shape <X1, Y2>, use Y2 as AK1
|
||||
constexpr auto AK1 = number<TileEncodingPattern::Y2>{};
|
||||
constexpr auto AK0 = number<KPerBlock / AK1>{};
|
||||
// How the M dimension is split across threads
|
||||
constexpr auto M0 = TileEncodingPattern::X0; // # of threads in M dim
|
||||
constexpr auto M1 = number<MPerBlock / M0>{};
|
||||
|
||||
// Get the warp tile size
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
constexpr auto MPerXdl = number<WarpTile::at(I0)>{};
|
||||
|
||||
// Number of threads covering K dimension
|
||||
constexpr auto KThreadWrite = TileEncodingPattern::Y0 * TileEncodingPattern::Y1;
|
||||
constexpr auto K0PerThreadWrite = AK0 / KThreadWrite;
|
||||
constexpr auto KThreadRead = get_warp_size() / MPerXdl;
|
||||
constexpr auto K0PerThreadRead = AK0 / KThreadRead;
|
||||
|
||||
// check if we exceed all LDS banks
|
||||
constexpr auto LdsBanksWidth = get_n_lds_banks() * get_n_dwords_per_128b();
|
||||
constexpr auto kfold = (AK1 * M0 * sizeof(ADataType) > LdsBanksWidth)
|
||||
? 1
|
||||
: LdsBanksWidth / (AK1 * M0 * sizeof(ADataType));
|
||||
constexpr auto KThreadReadPerm =
|
||||
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
|
||||
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
|
||||
: KThreadRead;
|
||||
|
||||
// 1<=mpair<=n0
|
||||
constexpr auto mpair =
|
||||
(AK1 * MPerXdl * sizeof(ADataType) > LdsBanksWidth)
|
||||
? 1
|
||||
: ((LdsBanksWidth / (AK1 * MPerXdl * sizeof(ADataType))) > M0
|
||||
? M0
|
||||
: LdsBanksWidth / (AK1 * MPerXdl * sizeof(ADataType)));
|
||||
|
||||
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
|
||||
number<K0PerThreadWrite>{},
|
||||
number<KThreadReadPerm * M1>{},
|
||||
number<kfold * M0 / mpair>{},
|
||||
number<mpair>{},
|
||||
AK1),
|
||||
AK1);
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc,
|
||||
make_tuple(make_pass_through_transform(
|
||||
number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(number<K0PerThreadWrite>{}),
|
||||
make_xor_transform(make_tuple(number<KThreadReadPerm * M1>{},
|
||||
number<kfold * M0 / mpair>{})),
|
||||
make_pass_through_transform(number<mpair>{}),
|
||||
make_pass_through_transform(AK1)),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2, 3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2, 3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(
|
||||
make_pass_through_transform(
|
||||
number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(number<K0PerThreadWrite>{}),
|
||||
make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<M1>{})),
|
||||
make_unmerge_transform(make_tuple(number<kfold>{}, number<M0 / mpair>{})),
|
||||
make_pass_through_transform(number<mpair>{}),
|
||||
make_pass_through_transform(AK1)),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}),
|
||||
make_tuple(sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<0, 3>{},
|
||||
sequence<4, 5>{},
|
||||
sequence<6>{},
|
||||
sequence<7>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_ak0_m_ak1 = transform_tensor_descriptor(
|
||||
a_lds_block_desc_unmerged,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KThreadReadPerm>{},
|
||||
number<KThreadWrite / kfold / KThreadReadPerm>{},
|
||||
number<kfold>{},
|
||||
number<K0PerThreadWrite>{},
|
||||
AK1)),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<M0 / mpair>{}, number<mpair>{}, number<M1>{}))),
|
||||
make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return a_lds_block_desc_ak0_m_ak1;
|
||||
}
|
||||
else // A is in RowMajor
|
||||
{
|
||||
constexpr index_t KPack = Derived::template GetSmemPackA<Problem>();
|
||||
constexpr auto DataTypeSize = sizeof(ADataType);
|
||||
constexpr uint64_t MinLdsLayer = 1ULL;
|
||||
constexpr auto MLdsLayer =
|
||||
max(MinLdsLayer,
|
||||
get_n_lds_banks() * get_n_dwords_per_128b() / KPerBlock / DataTypeSize);
|
||||
|
||||
constexpr index_t NBanks = get_n_lds_banks();
|
||||
static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count");
|
||||
constexpr index_t RowMul = (NBanks == 64) ? 2 : 1;
|
||||
|
||||
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(number<KPerBlock / KPack * MLdsLayer>{},
|
||||
number<MPerBlock / MLdsLayer>{},
|
||||
number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock * MLdsLayer>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
a_lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_xor_transform(make_tuple(number<MPerBlock / MLdsLayer * RowMul>{},
|
||||
number<KPerBlock / KPack * MLdsLayer>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
|
||||
a_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(
|
||||
make_tuple(number<MLdsLayer>{}, number<KPerBlock / KPack>{})),
|
||||
make_pass_through_transform(number<MPerBlock / MLdsLayer>{}),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto a_lds_block_desc = transform_tensor_descriptor(
|
||||
a_lds_block_desc_xk0_mnldslayer_mn_xk1,
|
||||
make_tuple(make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<MPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return a_lds_block_desc;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Create LDS block descriptor for B tensor.
|
||||
*
|
||||
* @tparam Problem Gemm pipeline problem.
|
||||
* @return B tensor LDS block descriptor.
|
||||
*/
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
|
||||
{
|
||||
using BLayout = remove_cvref_t<typename Problem::BLayout>;
|
||||
constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v<Problem>;
|
||||
using BDataType = std::conditional_t<IsBCastPolicyBeforeLDSWrite,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
if constexpr(is_b_load_tr<Problem>)
|
||||
{
|
||||
// TODO: better lds descriptor for performance
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( //
|
||||
make_tuple(number<KPerBlock>{}, number<NPerBlock>{}),
|
||||
make_tuple(number<NPerBlock>{}, number<1>{}),
|
||||
number<NPerBlock>{},
|
||||
number<1>{});
|
||||
return b_lds_block_desc_0;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Only use this RowMajor layout for Wave64 mode (gfx9)
|
||||
constexpr auto Wave64 = get_warp_size() == 64;
|
||||
if constexpr(Wave64 && std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
|
||||
using TileEncodingPattern =
|
||||
tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
getBTileAccessPattern()>;
|
||||
// BK1: the shuffled tile dstr has shape <X1, Y2>, use Y2 as BK1
|
||||
constexpr auto BK1 = number<TileEncodingPattern::Y2>{};
|
||||
constexpr auto BK0 = number<KPerBlock / BK1>{};
|
||||
|
||||
// How threads access data on N dim
|
||||
constexpr auto N0 = TileEncodingPattern::X0; // # of threads in N dim
|
||||
constexpr auto N1 = number<NPerBlock / N0>{};
|
||||
|
||||
// Get NPerXdl, the warp tile size
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
constexpr auto NPerXdl = number<WarpTile::at(I1)>{};
|
||||
|
||||
// Number of threads covering K dimension
|
||||
constexpr auto KThreadWrite = TileEncodingPattern::Y0 * TileEncodingPattern::Y1;
|
||||
constexpr auto K0PerThreadWrite = BK0 / KThreadWrite;
|
||||
constexpr auto KThreadRead = get_warp_size() / NPerXdl;
|
||||
constexpr auto K0PerThreadRead = BK0 / KThreadRead;
|
||||
|
||||
// check if we exceed all LDS banks
|
||||
constexpr auto LdsBanksWidth = get_n_lds_banks() * get_n_dwords_per_128b();
|
||||
constexpr auto kfold = (BK1 * N0 * sizeof(BDataType) > LdsBanksWidth)
|
||||
? 1
|
||||
: LdsBanksWidth / (BK1 * N0 * sizeof(BDataType));
|
||||
constexpr auto KThreadReadPerm =
|
||||
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
|
||||
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
|
||||
: KThreadRead;
|
||||
|
||||
// 1<=npair<=n0
|
||||
constexpr auto npair =
|
||||
(BK1 * NPerXdl * sizeof(BDataType) > LdsBanksWidth)
|
||||
? 1
|
||||
: ((LdsBanksWidth / (BK1 * NPerXdl * sizeof(BDataType))) > N0
|
||||
? N0
|
||||
: LdsBanksWidth / (BK1 * NPerXdl * sizeof(BDataType)));
|
||||
|
||||
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
|
||||
make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
|
||||
number<K0PerThreadWrite>{},
|
||||
number<KThreadReadPerm * N1>{},
|
||||
number<kfold * N0 / npair>{},
|
||||
number<npair>{},
|
||||
BK1),
|
||||
BK1);
|
||||
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc,
|
||||
make_tuple(make_pass_through_transform(
|
||||
number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(number<K0PerThreadWrite>{}),
|
||||
make_xor_transform(make_tuple(number<KThreadReadPerm * N1>{},
|
||||
number<kfold * N0 / npair>{})),
|
||||
make_pass_through_transform(number<npair>{}),
|
||||
make_pass_through_transform(BK1)),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2, 3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2, 3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(
|
||||
make_pass_through_transform(
|
||||
number<KThreadWrite / kfold / KThreadReadPerm>{}),
|
||||
make_pass_through_transform(number<K0PerThreadWrite>{}),
|
||||
make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<N1>{})),
|
||||
make_unmerge_transform(make_tuple(number<kfold>{}, number<N0 / npair>{})),
|
||||
make_pass_through_transform(number<npair>{}),
|
||||
make_pass_through_transform(BK1)),
|
||||
make_tuple(sequence<0>{},
|
||||
sequence<1>{},
|
||||
sequence<2>{},
|
||||
sequence<3>{},
|
||||
sequence<4>{},
|
||||
sequence<5>{}),
|
||||
make_tuple(
|
||||
sequence<1>{}, // 0: K0PerThreadWrite
|
||||
sequence<2>{}, // 1: KThreadReadPerm
|
||||
sequence<0, 3>{}, // 2: KThreadWrite / kfold / KThreadReadPerm, 3: N1
|
||||
sequence<4, 5>{}, // 4: kfold, 5: N0 / npair
|
||||
sequence<6>{}, // 6: npair
|
||||
sequence<7>{})); // 7: BK1
|
||||
|
||||
constexpr auto b_lds_block_desc_nk = transform_tensor_descriptor(
|
||||
b_lds_block_desc_unmerged,
|
||||
make_tuple(make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<KThreadReadPerm>{},
|
||||
number<KThreadWrite / kfold / KThreadReadPerm>{},
|
||||
number<kfold>{},
|
||||
number<K0PerThreadWrite>{},
|
||||
BK1)),
|
||||
make_merge_transform_v3_division_mod(make_tuple(
|
||||
number<N0 / npair>{}, number<npair>{}, number<N1>{}))),
|
||||
make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
|
||||
make_tuple(sequence<1>{}, sequence<0>{}));
|
||||
|
||||
return b_lds_block_desc_nk;
|
||||
}
|
||||
else // B is Column Major
|
||||
{
|
||||
constexpr index_t KPack = GetSmemPackB<Problem>();
|
||||
constexpr auto BK0 = number<KPerBlock / KPack>{};
|
||||
constexpr auto DataTypeSize = sizeof(BDataType);
|
||||
constexpr uint64_t MinLdsLayer = 1ULL;
|
||||
constexpr auto NLdsLayer =
|
||||
max(MinLdsLayer,
|
||||
get_n_lds_banks() * get_n_dwords_per_128b() / KPerBlock / DataTypeSize);
|
||||
|
||||
constexpr index_t NBanks = get_n_lds_banks();
|
||||
static_assert(NBanks == 32 || NBanks == 64, "Unexpected LDS bank count");
|
||||
constexpr index_t RowMul = (NBanks == 64) ? 2 : 1;
|
||||
|
||||
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
|
||||
make_tuple(BK0 * number<NLdsLayer>{},
|
||||
number<NPerBlock / NLdsLayer>{},
|
||||
number<KPack>{}),
|
||||
make_tuple(number<KPack>{}, number<KPerBlock * NLdsLayer>{}, number<1>{}),
|
||||
number<KPack>{},
|
||||
number<1>{});
|
||||
|
||||
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
|
||||
b_lds_block_desc_0,
|
||||
make_tuple(
|
||||
make_xor_transform(make_tuple(number<NPerBlock / NLdsLayer * RowMul>{},
|
||||
BK0 * number<NLdsLayer>{})),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
|
||||
b_lds_block_desc_permuted,
|
||||
make_tuple(make_unmerge_transform(make_tuple(number<NLdsLayer>{}, BK0)),
|
||||
make_pass_through_transform(number<NPerBlock / NLdsLayer>{}),
|
||||
make_pass_through_transform(number<KPack>{})),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
|
||||
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
|
||||
|
||||
constexpr auto b_lds_block_desc = transform_tensor_descriptor(
|
||||
b_lds_block_desc_bk0_nldslayer_n_bk1,
|
||||
make_tuple(
|
||||
make_merge_transform_v3_division_mod(
|
||||
make_tuple(number<NPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
|
||||
make_merge_transform_v3_division_mod(make_tuple(BK0, number<KPack>{}))),
|
||||
make_tuple(sequence<1, 0>{}, sequence<2, 3>{}),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
return b_lds_block_desc;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the maximum global memory vector load size.
|
||||
*
|
||||
* @tparam Problem The UniversalGemmPipelineProblem object.
|
||||
* @tparam DataType The tensor data type we're considering.
|
||||
* @tparam MNPerBlock The MPerBlock or NPerBlock value depending on tensor (A/B).
|
||||
* @tparam XPerTile The contiguous Tile dimension size.
|
||||
* @return Maximum DRAM vector load size.
|
||||
*/
|
||||
template <typename Problem,
|
||||
typename DataType,
|
||||
index_t MNPerBlock,
|
||||
index_t XPerTile,
|
||||
bool IsWave32Host>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetGlobalVectorLoadSize()
|
||||
{
|
||||
constexpr index_t BlockSize = IsWave32Host ? Problem::kBlockSize / 2 : Problem::kBlockSize;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
|
||||
constexpr index_t PackedSize =
|
||||
ck_tile::numeric_traits<remove_cvref_t<DataType>>::PackedSize;
|
||||
|
||||
// Assume DataType is even!
|
||||
if constexpr(XPerTile % (PackedSize * 16 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 16 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (PackedSize * 16 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(XPerTile % (PackedSize * 8 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 8 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (PackedSize * 8 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(sizeof(DataType) >= PackedSize * 4 &&
|
||||
XPerTile % (PackedSize * 4 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 4 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (PackedSize * 4 / sizeof(DataType));
|
||||
}
|
||||
else if constexpr(sizeof(DataType) >= PackedSize * 2 &&
|
||||
XPerTile % (PackedSize * 2 / sizeof(DataType)) == 0 &&
|
||||
elements_per_thread % (PackedSize * 2 / sizeof(DataType)) == 0)
|
||||
{
|
||||
return (PackedSize * 2 / sizeof(DataType));
|
||||
}
|
||||
else
|
||||
{
|
||||
return PackedSize;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem, bool IsWave32Host = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<number<0>{}, AsLayout>>;
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<number<0>{}, AsDataType>>;
|
||||
|
||||
if constexpr(Problem::FixedVectorSize)
|
||||
{
|
||||
return Problem::VectorSizeA;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return GetGlobalVectorLoadSize<Problem,
|
||||
ADataType,
|
||||
MPerBlock,
|
||||
KPerBlock,
|
||||
IsWave32Host>();
|
||||
}
|
||||
else
|
||||
{
|
||||
return GetGlobalVectorLoadSize<Problem,
|
||||
ADataType,
|
||||
MPerBlock,
|
||||
MPerBlock,
|
||||
IsWave32Host>();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem, bool IsWave32Host = false>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetVectorSizeB()
|
||||
{
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<number<0>{}, BsLayout>>;
|
||||
|
||||
constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v<Problem>;
|
||||
using BDataType = std::conditional_t<IsBCastPolicyBeforeLDSWrite,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
|
||||
if constexpr(Problem::FixedVectorSize)
|
||||
{
|
||||
return Problem::VectorSizeB;
|
||||
}
|
||||
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
return GetGlobalVectorLoadSize<Problem,
|
||||
BDataType,
|
||||
NPerBlock,
|
||||
NPerBlock,
|
||||
IsWave32Host>();
|
||||
}
|
||||
else
|
||||
{
|
||||
return GetGlobalVectorLoadSize<Problem,
|
||||
BDataType,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
IsWave32Host>();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Get the vector store size for C tensor.
|
||||
*
|
||||
* @tparam Problem - Gemm pipeline problem class.
|
||||
*
|
||||
* @note The vector store size for output C tensor would depend on multiple factors
|
||||
* like its data layout and warp gemm C transposition. In general it would
|
||||
* be the number of consecutive elements in contiguous C dimension hold by
|
||||
* single thread.
|
||||
*
|
||||
* @return The vector store size for C tensor.
|
||||
*/
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
|
||||
using WG = typename BlockGemm::WarpGemm;
|
||||
|
||||
constexpr bool TransposeC = Problem::TransposeC;
|
||||
using CLayout = typename Problem::CLayout;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
// N is contiguous dimension
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if constexpr(TransposeC)
|
||||
{
|
||||
// In this case each thread has multiple consecutive elements in
|
||||
// N dimension, however consecutive threads' elements have stride.
|
||||
constexpr index_t NDimY = CWarpDstr::NDimY;
|
||||
constexpr auto c_warp_y_lengths =
|
||||
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
|
||||
static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
|
||||
c_warp_y_lengths.get(number<NDimY - 1>{}));
|
||||
return c_warp_y_lengths.get(number<NDimY - 1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
// In this case each thread has just a single item in Ndim
|
||||
return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
|
||||
}
|
||||
}
|
||||
// M is contiguous dimension
|
||||
else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
if constexpr(TransposeC)
|
||||
{
|
||||
// In this case each thread has just a single item in Mdim
|
||||
return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
|
||||
}
|
||||
else
|
||||
{
|
||||
// In this case each thread has multiple consecutive elements in
|
||||
// M dimension, however consecutive threads' elements have stride.
|
||||
constexpr index_t NDimY = CWarpDstr::NDimY;
|
||||
constexpr auto c_warp_y_lengths =
|
||||
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
|
||||
static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
|
||||
c_warp_y_lengths.get(number<NDimY - 1>{}));
|
||||
return c_warp_y_lengths.get(number<NDimY - 1>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported CLayout!");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
|
||||
{
|
||||
return Problem::TransposeC;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize =
|
||||
Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
using ALayout = remove_cvref_t<
|
||||
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::AsLayoutTuple>>>;
|
||||
// Tile: MPerBlock X KPerBlock
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
using TileEncodingPattern =
|
||||
tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
MPerBlock,
|
||||
KPerBlock,
|
||||
VecLoadSize,
|
||||
getATileAccessPattern(),
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
// Tile: KPerBlock X MPerBlock
|
||||
else
|
||||
{
|
||||
using TileEncodingPattern =
|
||||
tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
KPerBlock,
|
||||
MPerBlock,
|
||||
VecLoadSize,
|
||||
getATileAccessPattern(),
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
|
||||
{
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
// If we cast before writing to LDS, the vectorsize is defined by the A type
|
||||
// since the assumption is that A type is going to be the B LDS type
|
||||
constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v<Problem>;
|
||||
constexpr index_t VecLoadSize =
|
||||
IsBCastPolicyBeforeLDSWrite
|
||||
? (Problem::FixedVectorSize ? Problem::VectorSizeA : GetVectorSizeA<Problem>())
|
||||
: (Problem::FixedVectorSize ? Problem::VectorSizeB : GetVectorSizeB<Problem>());
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
using BLayout = remove_cvref_t<
|
||||
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;
|
||||
// Tile: KPerBlock X NPerBlock
|
||||
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
using TileEncodingPattern =
|
||||
tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
getBTileAccessPattern(),
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
// Tile: NPerBlock X KPerBlock
|
||||
else
|
||||
{
|
||||
using TileEncodingPattern =
|
||||
tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
VecLoadSize,
|
||||
getBTileAccessPattern(),
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::make_2d_static_tile_distribution();
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegTileDistribution()
|
||||
{
|
||||
using ALayout = remove_cvref_t<
|
||||
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::AsLayoutTuple>>>;
|
||||
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
KPerBlock,
|
||||
MPerBlock,
|
||||
VecLoadSize,
|
||||
getATileAccessPattern(),
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::make_shuffled_2d_static_tile_distribution();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegTileDistribution()
|
||||
{
|
||||
using BLayout = remove_cvref_t<
|
||||
std::tuple_element_t<number<0>{}, remove_cvref_t<typename Problem::BsLayoutTuple>>>;
|
||||
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
|
||||
constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
using TileEncodingPattern = tile_distribution_encoding_pattern_2d<BlockSize,
|
||||
KPerBlock,
|
||||
NPerBlock,
|
||||
VecLoadSize,
|
||||
getBTileAccessPattern(),
|
||||
NumWaveGroups>;
|
||||
return TileEncodingPattern::make_shuffled_2d_static_tile_distribution();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPackA()
|
||||
{
|
||||
using A = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
|
||||
|
||||
constexpr index_t KPack = static_cast<index_t>(BlockGemm::Traits::KPack);
|
||||
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(A));
|
||||
|
||||
return (KPack < VecElems) ? KPack : VecElems;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemPackB()
|
||||
{
|
||||
using B = remove_cvref_t<typename Problem::BDataType>;
|
||||
using BlockGemm = remove_cvref_t<decltype(Derived::template GetBlockGemm<Problem>())>;
|
||||
|
||||
constexpr index_t KPack = static_cast<index_t>(BlockGemm::Traits::KPack);
|
||||
constexpr index_t VecElems = static_cast<index_t>(Problem::VectorLoadSize / sizeof(B));
|
||||
|
||||
return (KPack < VecElems) ? KPack : VecElems;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
constexpr auto APackedSize = numeric_traits<ADataType>::PackedSize;
|
||||
constexpr auto a_lds_block_desc = Derived::template MakeALdsBlockDescriptor<Problem>();
|
||||
constexpr index_t smem_size_a = integer_least_multiple(
|
||||
a_lds_block_desc.get_element_space_size() * sizeof(ADataType) / APackedSize, 16);
|
||||
return smem_size_a;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
|
||||
{
|
||||
constexpr bool IsBCastPolicyBeforeLDSWrite = IsBCastPolicyBeforeLDSWrite_v<Problem>;
|
||||
using BDataType = std::conditional_t<IsBCastPolicyBeforeLDSWrite,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::BDataType>;
|
||||
constexpr auto BPackedSize = numeric_traits<BDataType>::PackedSize;
|
||||
constexpr auto b_lds_block_desc = Derived::template MakeBLdsBlockDescriptor<Problem>();
|
||||
constexpr index_t smem_size_b = integer_least_multiple(
|
||||
b_lds_block_desc.get_element_space_size() * sizeof(BDataType) / BPackedSize, 16);
|
||||
return smem_size_b;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
|
||||
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
|
||||
|
||||
return smem_size_a + smem_size_b;
|
||||
}
|
||||
};
|
||||
|
||||
// UniversalGemm Policy
|
||||
struct UniversalGemmPipelineAgBgCrPolicy
|
||||
: public UniversalGemmBasePolicy<UniversalGemmPipelineAgBgCrPolicy>
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
|
||||
constexpr index_t vector_size =
|
||||
DS_READ_TR_SIZE() / sizeof(typename Problem::ComputeDataType);
|
||||
constexpr index_t thread_elements = WarpTile::at(I1) * WarpTile::at(I2) / get_warp_size();
|
||||
constexpr auto wg_attr_num_access =
|
||||
!(is_a_load_tr<Problem> || is_b_load_tr<Problem>) ? WGAttrNumAccessEnum::Single
|
||||
: vector_size == thread_elements ? WGAttrNumAccessEnum::Single
|
||||
: vector_size * 2 == thread_elements ? WGAttrNumAccessEnum::Double
|
||||
: vector_size * 4 == thread_elements ? WGAttrNumAccessEnum::Quad
|
||||
: WGAttrNumAccessEnum::Invalid;
|
||||
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using BDataType = remove_cvref_t<typename Problem::BDataType>;
|
||||
using ATypeToUse =
|
||||
std::conditional_t<std::is_same_v<ADataType, pk_int4_t>, BDataType, ADataType>;
|
||||
using BTypeToUse = std::conditional_t<std::is_same_v<BDataType, pk_int4_t> ||
|
||||
std::is_same_v<BDataType, pk_fp4_t> ||
|
||||
sizeof(BDataType) < sizeof(ADataType),
|
||||
ADataType,
|
||||
BDataType>;
|
||||
|
||||
using WarpGemm = WarpGemmDispatcher<ATypeToUse,
|
||||
BTypeToUse,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC,
|
||||
false,
|
||||
Problem::UseStructuredSparsity,
|
||||
wg_attr_num_access>;
|
||||
|
||||
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<ATypeToUse,
|
||||
BTypeToUse,
|
||||
typename Problem::CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
return BlockUniversalGemmAsBsCr<Problem, BlockGemmPolicy>{};
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
69
include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp
Normal file
69
include/ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp
Normal file
@@ -0,0 +1,69 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BlockTile_,
|
||||
typename BlockWarps_,
|
||||
typename WarpTile_,
|
||||
bool PermuteA_ = false,
|
||||
bool PermuteB_ = false>
|
||||
struct TileGemmShape
|
||||
{
|
||||
using BlockTile = remove_cvref_t<BlockTile_>;
|
||||
using BlockWarps = remove_cvref_t<BlockWarps_>;
|
||||
using WarpTile = remove_cvref_t<WarpTile_>;
|
||||
|
||||
static constexpr index_t NumWarps =
|
||||
reduce_on_sequence(BlockWarps{}, multiplies<>{}, number<1>{});
|
||||
|
||||
static constexpr index_t kM = BlockTile::at(number<0>{});
|
||||
static constexpr index_t kN = BlockTile::at(number<1>{});
|
||||
static constexpr index_t kK = BlockTile::at(number<2>{});
|
||||
|
||||
static constexpr bool PermuteA = PermuteA_;
|
||||
static constexpr bool PermuteB = PermuteB_;
|
||||
|
||||
static constexpr index_t flatNPerWarp = BlockWarps::at(number<1>{});
|
||||
static constexpr index_t flatKPerWarp = WarpTile::at(number<2>{}) * WarpTile::at(number<1>{});
|
||||
static constexpr index_t flatKPerBlock = flatKPerWarp * kK / WarpTile::at(number<2>{});
|
||||
|
||||
CK_TILE_HOST static std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "tile_gemm_shape",
|
||||
concat('x', kM, kN, kK, NumWarps),
|
||||
concat('x', BlockWarps::at(number<0>{}), BlockWarps::at(number<1>{}), BlockWarps::at(number<2>{})),
|
||||
concat('x', (WarpTile::at(number<0>{})), WarpTile::at(number<1>{}), WarpTile::at(number<2>{})));
|
||||
// clang-format on
|
||||
}
|
||||
};
|
||||
|
||||
template <typename PrecType, index_t M_Warp_Tile, bool IsFlatMM = false>
|
||||
constexpr index_t get_k_warp_tile()
|
||||
{
|
||||
#if CK_TILE_USE_WMMA
|
||||
return 16;
|
||||
#else
|
||||
#if defined(CK_GFX950_SUPPORT)
|
||||
constexpr bool is_8bit_float =
|
||||
std::is_same_v<PrecType, fp8_t> || std::is_same_v<PrecType, bf8_t>;
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return is_8bit_float ? 64 : 16;
|
||||
else
|
||||
return is_8bit_float ? 128 : 32;
|
||||
#else
|
||||
if constexpr(M_Warp_Tile == 32)
|
||||
return (sizeof(PrecType) == 2 || IsFlatMM == false) ? 16 : 32;
|
||||
else
|
||||
return (sizeof(PrecType) == 2 || IsFlatMM == false) ? 32 : 64;
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
87
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
Normal file
87
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
Normal file
@@ -0,0 +1,87 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_,
|
||||
typename AsLayout_,
|
||||
typename BsLayout_,
|
||||
typename CLayout_,
|
||||
index_t NumWaveGroups_ = 1>
|
||||
struct TileGemmTraits
|
||||
{
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kPadK = kPadK_;
|
||||
|
||||
// TODO this can't be hardcoded here! Should be in policy!
|
||||
static constexpr int _VectorSize = 16;
|
||||
|
||||
using AsLayout = AsLayout_;
|
||||
using BsLayout = BsLayout_;
|
||||
using CLayout = CLayout_;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
static constexpr index_t NumWaveGroups = NumWaveGroups_;
|
||||
};
|
||||
|
||||
template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_,
|
||||
bool DoubleSmemBuffer_,
|
||||
typename AsLayout_,
|
||||
typename BsLayout_,
|
||||
typename CLayout_,
|
||||
bool TransposeC_ = false,
|
||||
bool UseStructuredSparsity_ = false,
|
||||
bool UsePersistentKernel_ = false,
|
||||
index_t NumWaveGroups_ = 1,
|
||||
bool Preshuffle_ = false,
|
||||
int VectorSize_ = 16>
|
||||
struct TileGemmUniversalTraits
|
||||
{
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kPadK = kPadK_;
|
||||
static constexpr int _VectorSize = VectorSize_;
|
||||
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
|
||||
|
||||
using AsLayout = AsLayout_;
|
||||
using BsLayout = BsLayout_;
|
||||
using CLayout = CLayout_;
|
||||
static constexpr bool TransposeC = TransposeC_;
|
||||
|
||||
static constexpr bool UseStructuredSparsity = UseStructuredSparsity_;
|
||||
static constexpr bool UsePersistentKernel = UsePersistentKernel_;
|
||||
static constexpr index_t NumWaveGroups = NumWaveGroups_;
|
||||
static constexpr bool Preshuffle = Preshuffle_;
|
||||
};
|
||||
|
||||
template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_,
|
||||
bool DoubleSmemBuffer_,
|
||||
typename AsLayout_,
|
||||
typename BsLayout_,
|
||||
typename CLayout_,
|
||||
bool TransposeC_ = false,
|
||||
bool UseStructuredSparsity_ = false>
|
||||
using PersistentTileGemmUniversalTraits = TileGemmUniversalTraits<kPadM_,
|
||||
kPadN_,
|
||||
kPadK_,
|
||||
DoubleSmemBuffer_,
|
||||
AsLayout_,
|
||||
BsLayout_,
|
||||
CLayout_,
|
||||
TransposeC_,
|
||||
UseStructuredSparsity_,
|
||||
true>;
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,365 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/core/numeric/numeric.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_wp_asmem_breg_creg.hpp"
|
||||
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
struct UniversalWeightPreshufflePipelineAgBgCrPolicy
|
||||
: public UniversalGemmBasePolicy<UniversalWeightPreshufflePipelineAgBgCrPolicy>
|
||||
{
|
||||
using BasePolicy = UniversalGemmBasePolicy<UniversalWeightPreshufflePipelineAgBgCrPolicy>;
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
|
||||
{
|
||||
constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
|
||||
MakeALdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
return smem_size_a;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
|
||||
|
||||
return smem_size_a;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
|
||||
{
|
||||
return Problem::VectorLoadSize / sizeof(typename Problem::ADataType);
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetKBPerLoad()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
#if defined(__gfx11__)
|
||||
constexpr index_t scale = 4;
|
||||
#else
|
||||
constexpr index_t scale = get_warp_size() == 32 ? 2 : 1;
|
||||
#endif
|
||||
if constexpr(TileShape::WarpTile::at(I1) == 32)
|
||||
{
|
||||
return TileShape::WarpTile::at(I2) * scale / 2;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(TileShape::WarpTile::at(I1) == 16);
|
||||
return TileShape::WarpTile::at(I2) * scale / 4;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
|
||||
{
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
|
||||
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t M0 = MPerBlock / M1;
|
||||
constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
|
||||
static_assert(total_pixels % M1 == 0);
|
||||
constexpr index_t K3 = total_pixels / M1;
|
||||
constexpr index_t KPack = GetSmemPackA<Problem>();
|
||||
static_assert(KPack % K3 == 0);
|
||||
constexpr index_t K2 = KPack / K3;
|
||||
if constexpr(get_warp_size() >= (K2 * M0))
|
||||
{
|
||||
constexpr index_t K1 = get_warp_size() / (K2 * M0);
|
||||
constexpr index_t K0 = BlockSize / get_warp_size();
|
||||
static_assert(KPerBlock == K0 * K1 * K2 * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
|
||||
tuple<sequence<2>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0, 2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<3, 1>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = (K2 * M0) / get_warp_size();
|
||||
constexpr index_t K2_m = K2 / K1;
|
||||
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
|
||||
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
|
||||
tuple<sequence<2, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 2>>,
|
||||
sequence<2, 1>,
|
||||
sequence<3, 1>>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t K0 = KPerBlock / K1;
|
||||
constexpr index_t M2 = get_warp_size() / K0;
|
||||
// coalesce reading for each blocks
|
||||
if constexpr(get_warp_size() % (M2 * K0) == 0)
|
||||
{
|
||||
constexpr index_t M1 = BlockSize / get_warp_size();
|
||||
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
|
||||
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
|
||||
constexpr index_t M0 = MPerBlock / (M2 * M1);
|
||||
static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
"Incorrect M0, M2, M1 configuration! "
|
||||
"M0, M1, M2 must cover whole MPerBlock!");
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t M0 = BlockSize / get_warp_size();
|
||||
constexpr index_t M1 = MPerBlock / (M2 * M0);
|
||||
static_assert(M0 * M1 * M2 == MPerBlock,
|
||||
"Incorrect M0, M1, M2 configuration! "
|
||||
"M0, M1, M2 must cover whole MPerBlock!");
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeBFlatDramTileDistribution()
|
||||
{
|
||||
using TileShape = typename Problem::BlockGemmShape;
|
||||
using BDataType = typename Problem::BDataType;
|
||||
|
||||
constexpr index_t kNPerBlock = TileShape::kN;
|
||||
constexpr index_t kKPerBlock = TileShape::kK;
|
||||
constexpr index_t NIterPerWarp =
|
||||
kNPerBlock / TileShape::BlockWarps::at(I1) / TileShape::WarpTile::at(I1);
|
||||
constexpr index_t KIterPerWarp = kKPerBlock / TileShape::WarpTile::at(I2);
|
||||
|
||||
constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t WaveNum = BlockSize / WaveSize;
|
||||
|
||||
#if defined(__gfx11__)
|
||||
constexpr index_t KRepeatInWave = 2;
|
||||
#else
|
||||
constexpr index_t KRepeatInWave = 1;
|
||||
#endif
|
||||
constexpr index_t KBPerLoad = min(
|
||||
GetKBPerLoad<Problem>(), KRepeatInWave * 16 / static_cast<index_t>(sizeof(BDataType)));
|
||||
constexpr index_t KThdPerWave = WaveSize / KRepeatInWave; // threads cnt in K dim
|
||||
constexpr index_t KWavePerBlk = 1;
|
||||
constexpr index_t KRepeat = KIterPerWarp;
|
||||
constexpr index_t KAccess = GetKBPerLoad<Problem>() / KBPerLoad;
|
||||
static_assert(TileShape::flatKPerWarp == KAccess * KThdPerWave * KBPerLoad, "wrong");
|
||||
|
||||
constexpr index_t NBPerLoad = 1;
|
||||
constexpr index_t NThdPerWave = 1;
|
||||
constexpr index_t NWavePerBlk = TileShape::BlockWarps::at(number<1>{}); // N_Warp
|
||||
constexpr index_t NRepeat = NIterPerWarp;
|
||||
|
||||
constexpr index_t WaveRepeat = WaveNum / TileShape::flatNPerWarp;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<
|
||||
sequence<WaveRepeat, KRepeatInWave>, // ?
|
||||
tuple<sequence<NRepeat, NWavePerBlk, NThdPerWave, NBPerLoad>, // second direction
|
||||
sequence<KRepeat, KAccess, KWavePerBlk, KThdPerWave, KBPerLoad>>,
|
||||
// wave in blk, // thd in wave
|
||||
// <M, K> // <M, K>
|
||||
tuple<sequence<0, 1, 2>, sequence<0, 1, 2>>, // which direction
|
||||
tuple<sequence<0, 1, 2>, sequence<1, 2, 3>>, // which index
|
||||
// <repeat, vec_load>
|
||||
sequence<1, 2, 1, 2, 2>,
|
||||
sequence<0, 0, 3, 1, 4>>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDistribution()
|
||||
{
|
||||
using ALayout = remove_cvref_t<typename Problem::ALayout>;
|
||||
using ADataType = remove_cvref_t<typename Problem::ADataType>;
|
||||
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
|
||||
constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
|
||||
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
|
||||
|
||||
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
constexpr index_t M0 = kMPerBlock / M1;
|
||||
constexpr index_t total_pixels = kMPerBlock * kKPerBlock / kBlockSize;
|
||||
static_assert(total_pixels % M1 == 0);
|
||||
constexpr index_t K3 = total_pixels / M1;
|
||||
constexpr index_t kKPack = GetSmemPackA<Problem>();
|
||||
static_assert(kKPack % K3 == 0);
|
||||
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
|
||||
constexpr index_t warp_size = get_warp_size();
|
||||
if constexpr(warp_size >= (K2 * M0))
|
||||
{
|
||||
constexpr index_t K1 = warp_size / (K2 * M0);
|
||||
constexpr index_t K0 = kBlockSize / warp_size;
|
||||
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
|
||||
tuple<sequence<2>, sequence<2, 1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 3>>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
constexpr index_t K1 = (K2 * M0) / get_warp_size();
|
||||
constexpr index_t K2_m = K2 / K1;
|
||||
constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
|
||||
static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
|
||||
tuple<sequence<2, 2>, sequence<1, 2>>,
|
||||
tuple<sequence<0, 1>, sequence<0, 2>>,
|
||||
sequence<1, 2>,
|
||||
sequence<1, 3>>{});
|
||||
}
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
|
||||
{
|
||||
return GetBlockWeightPreshuffle<Problem>();
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetBlockWeightPreshuffle()
|
||||
{
|
||||
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
|
||||
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
|
||||
|
||||
// Determine compute types to use
|
||||
// This logic defaults to A/B DataType, but if one of them is packed falls back to the other
|
||||
// If both are packed, it falls back to the explicitly defined ComputeDataType in the
|
||||
// problem It might be a good idea to use ComputeDataType anyway, but that would break how
|
||||
// this behaviour used to work
|
||||
using ATypeToUse = mixed_prec_compute_type_from_input_t<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::ComputeDataType>;
|
||||
using BTypeToUse = mixed_prec_compute_type_from_input_t<typename Problem::BDataType,
|
||||
typename Problem::ADataType,
|
||||
typename Problem::ComputeDataType>;
|
||||
|
||||
constexpr index_t WaveSize = get_warp_size();
|
||||
constexpr index_t KLane = WarpTile::at(I2) * WarpTile::at(I0) / WaveSize;
|
||||
// When BDataType is pk_int4_t, it is internally converted to fp8 for computation.
|
||||
constexpr index_t KLaneBytes = KLane * sizeof(BTypeToUse);
|
||||
constexpr auto NumAccess = static_cast<WGAttrNumAccessEnum>(max(1, KLaneBytes / 16));
|
||||
using WarpGemm = WarpGemmDispatcher<ATypeToUse,
|
||||
BTypeToUse,
|
||||
typename Problem::CDataType,
|
||||
WarpTile::at(I0),
|
||||
WarpTile::at(I1),
|
||||
WarpTile::at(I2),
|
||||
Problem::TransposeC,
|
||||
false,
|
||||
false,
|
||||
NumAccess>;
|
||||
|
||||
using BlockWeightPreshufflePolicy =
|
||||
BlockWeightPreshuffleASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
|
||||
typename Problem::BDataType,
|
||||
typename Problem::CDataType,
|
||||
BlockWarps,
|
||||
WarpGemm>;
|
||||
return BlockWeightPreshuffleASmemBRegCReg<Problem, BlockWeightPreshufflePolicy>{};
|
||||
}
|
||||
/**
|
||||
* @brief Get the vector store size for C tensor.
|
||||
*
|
||||
* @tparam Problem - Gemm pipeline problem class.
|
||||
*
|
||||
* @note The vector store size for output C tensor would depend on multiple factors
|
||||
* like its data layout and warp gemm C transposition. In general it would
|
||||
* be the number of consecutive elements in contiguous C dimension hold by
|
||||
* single thread.
|
||||
*
|
||||
* @return The vector store size for C tensor.
|
||||
*/
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetBlockWeightPreshuffle<Problem>())>;
|
||||
using WG_ = typename BlockGemm::WarpGemm;
|
||||
|
||||
constexpr bool TransposeC = Problem::TransposeC;
|
||||
using CLayout = typename Problem::CLayout;
|
||||
using CWarpDstr = typename WG_::CWarpDstr;
|
||||
|
||||
// N is contiguous dimension
|
||||
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
|
||||
{
|
||||
if constexpr(TransposeC)
|
||||
{
|
||||
// In this case each thread has multiple consecutive elements in
|
||||
// N dimension, however consecutive threads' elements have stride.
|
||||
constexpr index_t NDimY = CWarpDstr::NDimY;
|
||||
constexpr auto c_warp_y_lengths =
|
||||
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
|
||||
static_assert(WG_::WarpGemmAttribute::Impl::kCM1PerLane ==
|
||||
c_warp_y_lengths.get(number<NDimY - 1>{}));
|
||||
return c_warp_y_lengths.get(number<NDimY - 1>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
// In this case each thread has just a single item in Ndim
|
||||
return WG_::WarpGemmAttribute::Impl::kCNLane / WG_::kN;
|
||||
}
|
||||
}
|
||||
// M is contiguous dimension
|
||||
else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
|
||||
{
|
||||
if constexpr(TransposeC)
|
||||
{
|
||||
// In this case each thread has just a single item in Mdim
|
||||
return WG_::WarpGemmAttribute::Impl::kCNLane / WG_::kN;
|
||||
}
|
||||
else
|
||||
{
|
||||
// In this case each thread has multiple consecutive elements in
|
||||
// M dimension, however consecutive threads' elements have stride.
|
||||
constexpr index_t NDimY = CWarpDstr::NDimY;
|
||||
constexpr auto c_warp_y_lengths =
|
||||
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
|
||||
static_assert(WG_::WarpGemmAttribute::Impl::kCM1PerLane ==
|
||||
c_warp_y_lengths.get(number<NDimY - 1>{}));
|
||||
return c_warp_y_lengths.get(number<NDimY - 1>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported CLayout!");
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,796 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/load_and_convert_tile.hpp"
|
||||
#include "ck_tile/host/concat.hpp"
|
||||
#include "ck_tile/ops/gemm/pipeline/wp_pipeline_agmem_bgmem_creg_base_policy.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem>
|
||||
struct BaseWeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
{
|
||||
static constexpr index_t PrefetchStages = 2;
|
||||
static constexpr index_t PrefillStages = 1;
|
||||
static constexpr index_t GlobalBufferNum = 1;
|
||||
static constexpr bool UsePersistentKernel = Problem::Traits::UsePersistentKernel;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr bool BlockHasHotloop(index_t num_loop)
|
||||
{
|
||||
return num_loop > PrefetchStages;
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
|
||||
{
|
||||
return num_loop % 2 == 0 ? TailNumber::Even : TailNumber::Odd;
|
||||
}
|
||||
|
||||
template <typename RunFunction>
|
||||
CK_TILE_HOST_DEVICE static auto
|
||||
TailHandler(const RunFunction& run_func, bool has_hot_loop, TailNumber tail_number)
|
||||
{
|
||||
if(has_hot_loop)
|
||||
{
|
||||
if(tail_number == TailNumber::Odd)
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
}
|
||||
else // Even tail number
|
||||
{
|
||||
return run_func(bool_constant<true>{},
|
||||
integral_constant<TailNumber, TailNumber::Even>{});
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(tail_number == TailNumber::Odd)
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Odd>{});
|
||||
}
|
||||
else // Even tail number
|
||||
{
|
||||
return run_func(bool_constant<false>{},
|
||||
integral_constant<TailNumber, TailNumber::Even>{});
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Problem, typename PipelinePolicy = UniversalWeightPreshufflePipelineAgBgCrPolicy>
|
||||
struct WeightPreshufflePipelineAGmemBGmemCRegV2
|
||||
: public BaseWeightPreshufflePipelineAGmemBGmemCRegV2<Problem>
|
||||
{
|
||||
using Base = BaseWeightPreshufflePipelineAGmemBGmemCRegV2<Problem>;
|
||||
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, PipelinePolicy>;
|
||||
|
||||
using AsDataType = remove_cvref_t<typename Problem::AsDataTypeTuple>;
|
||||
using BsDataType = remove_cvref_t<typename Problem::BsDataTypeTuple>;
|
||||
using CDataType = remove_cvref_t<typename Problem::CDataType>;
|
||||
|
||||
using AElementWise = remove_cvref_t<typename Problem::AElementWise>;
|
||||
using BElementWise = remove_cvref_t<typename Problem::BElementWise>;
|
||||
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>; // TileFlatmmShape
|
||||
|
||||
using AsLayout = remove_cvref_t<typename Problem::AsLayoutTuple>;
|
||||
using BsLayout = remove_cvref_t<typename Problem::BsLayoutTuple>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
|
||||
using ALayout = remove_cvref_t<std::tuple_element_t<0, AsLayout>>;
|
||||
using BLayout = remove_cvref_t<std::tuple_element_t<0, BsLayout>>;
|
||||
|
||||
using ADataType = remove_cvref_t<std::tuple_element_t<0, AsDataType>>;
|
||||
using BDataType = remove_cvref_t<std::tuple_element_t<0, BsDataType>>;
|
||||
|
||||
using BlockWeightPreshuffle =
|
||||
remove_cvref_t<decltype(PipelinePolicy::template GetBlockWeightPreshuffle<Problem>())>;
|
||||
|
||||
static constexpr index_t DsWritePreIssue = 3; // default 2, ds write at MIter - 2
|
||||
static constexpr index_t DsReadPreload = 2; // default 2, preload 2 ds read
|
||||
|
||||
static constexpr index_t BlockSize = Problem::kBlockSize;
|
||||
static constexpr index_t WaveSize = get_warp_size();
|
||||
|
||||
static constexpr index_t kMPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t kNPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t kKPerBlock = BlockGemmShape::kK;
|
||||
|
||||
// bogus variables to compile grouped gemm (to be removed)
|
||||
static constexpr index_t MPerBlock = BlockGemmShape::kM;
|
||||
static constexpr index_t NPerBlock = BlockGemmShape::kN;
|
||||
static constexpr index_t KPerBlock = BlockGemmShape::kK;
|
||||
|
||||
static constexpr index_t kflatKPerBlock = BlockGemmShape::flatKPerBlock;
|
||||
|
||||
static constexpr index_t flatKPerWarp = BlockGemmShape::flatKPerWarp;
|
||||
static constexpr index_t flatNPerWarp = BlockGemmShape::flatNPerWarp;
|
||||
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeA()
|
||||
{
|
||||
return PipelinePolicy::template GetVectorSizeA<Problem, IsWave32Host>();
|
||||
}
|
||||
template <bool IsWave32Host = false>
|
||||
static constexpr index_t GetVectorSizeB()
|
||||
{
|
||||
return PipelinePolicy::template GetVectorSizeB<Problem, IsWave32Host>();
|
||||
}
|
||||
|
||||
static constexpr index_t GetVectorSizeC()
|
||||
{
|
||||
return PipelinePolicy::template GetVectorSizeC<Problem>();
|
||||
}
|
||||
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool kPadK = Problem::kPadK;
|
||||
|
||||
static constexpr index_t kLdsAlignmentInBytes = 16;
|
||||
static constexpr index_t NumWaveGroups = Problem::NumWaveGroups;
|
||||
|
||||
static constexpr auto I0 = number<0>();
|
||||
static constexpr auto I1 = number<1>();
|
||||
static constexpr auto I2 = number<2>();
|
||||
static constexpr auto idxM = I0;
|
||||
static constexpr auto idxN = I1;
|
||||
static constexpr auto idxK = I2;
|
||||
using BlockTile = remove_cvref_t<typename BlockGemmShape::BlockTile>;
|
||||
using BlockWarps = remove_cvref_t<typename BlockGemmShape::BlockWarps>;
|
||||
using WarpTile = remove_cvref_t<typename BlockGemmShape::WarpTile>;
|
||||
|
||||
static constexpr index_t MWarp = BlockWarps::at(I0);
|
||||
static constexpr index_t NWarp = BlockWarps::at(I1);
|
||||
|
||||
static constexpr index_t WarpTileM = WarpTile::at(I0);
|
||||
static constexpr index_t WarpTileN = WarpTile::at(I1);
|
||||
static constexpr index_t WarpTileK = WarpTile::at(I2);
|
||||
|
||||
static constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpTileM);
|
||||
static constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpTileN);
|
||||
static constexpr index_t KIterPerWarp = kKPerBlock / WarpTileK;
|
||||
|
||||
static constexpr index_t KFlatPerBlockPerIter = flatKPerWarp;
|
||||
static constexpr index_t NFlatPerBlockPerIter = flatNPerWarp;
|
||||
|
||||
static constexpr index_t MPerBlockPerIter = kMPerBlock / MIterPerWarp;
|
||||
static constexpr index_t KPerBlockPerIter = kKPerBlock / KIterPerWarp;
|
||||
|
||||
static constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
|
||||
static constexpr index_t m_preload = (MIterPerWarp * KIterPerWarp >= DsReadPreload)
|
||||
? DsReadPreload
|
||||
: MIterPerWarp * KIterPerWarp;
|
||||
|
||||
#ifdef __gfx942__
|
||||
static constexpr index_t mfma_per_wg = 2;
|
||||
#else
|
||||
static constexpr index_t mfma_per_wg = 1;
|
||||
#endif
|
||||
static constexpr index_t dsread_per_wg = max(
|
||||
index_t(WarpTileM * WarpTileK * sizeof(ADataType) / WaveSize / Problem::VectorLoadSize), 1);
|
||||
#if defined(__HIP_DEVICE_COMPILE__)
|
||||
static_assert((WarpTileM * WarpTileK * sizeof(ADataType) * MIterPerWarp / WaveSize) %
|
||||
Problem::VectorLoadSize ==
|
||||
0);
|
||||
#endif
|
||||
static constexpr index_t dsread_num_perK = WarpTileM * WarpTileK * sizeof(ADataType) *
|
||||
MIterPerWarp / WaveSize / Problem::VectorLoadSize;
|
||||
static constexpr index_t dswrite_num_perK = dsread_num_perK / (MWarp * NWarp);
|
||||
static constexpr index_t dswrite_rep = (dswrite_num_perK + MIterPerWarp - 1) / MIterPerWarp;
|
||||
static constexpr index_t Aload_num_perK = dswrite_num_perK;
|
||||
static constexpr index_t Aload_rep = dswrite_rep;
|
||||
static constexpr index_t Bload_num_perK = kNPerBlock * WarpTileK / NWarp / K1 / WaveSize;
|
||||
static constexpr index_t HalfMIter = (MIterPerWarp + 1) / 2;
|
||||
static constexpr index_t Bload_rep = (Bload_num_perK + HalfMIter - 1) / HalfMIter;
|
||||
|
||||
static constexpr index_t mfma_perM_perK = NIterPerWarp * mfma_per_wg;
|
||||
static constexpr index_t dswrite_mIter = (DsWritePreIssue - 1) % MIterPerWarp;
|
||||
static constexpr index_t dswrite_kIter = (DsWritePreIssue - 1) / MIterPerWarp;
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetPipelineName()
|
||||
{
|
||||
// clang-format off
|
||||
return "PRESHUFFLE_V2";
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
|
||||
{
|
||||
// clang-format off
|
||||
return concat('_', "pipeline_AGmemBGmemCRegV2",
|
||||
concat('x', kMPerBlock, kNPerBlock, kKPerBlock, BlockSize),
|
||||
concat('x', WarpTileM, WarpTileN, WarpTileK),
|
||||
concat('x', GetVectorSizeA(), GetVectorSizeB()),
|
||||
concat('x', kPadM, kPadN, kPadK));
|
||||
|
||||
// clang-format on
|
||||
}
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
|
||||
|
||||
static constexpr index_t Preshuffle = Problem::Preshuffle;
|
||||
using Base::UsePersistentKernel;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto TransposeC() { return Problem::TransposeC; }
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
constexpr index_t smem_size = PipelinePolicy::template GetSmemSize<Problem>();
|
||||
return DoubleSmemBuffer ? 2 * smem_size : smem_size;
|
||||
}
|
||||
|
||||
// dsread_perM: how many LDS reads want to issue in this M-iter
|
||||
// dswrite_perM: how many LDS writes you want to do this M-iter
|
||||
// load_perM: how many global loads VMEM want to do in this M-iter
|
||||
CK_TILE_HOST_DEVICE static constexpr auto
|
||||
SchedulerPerM(index_t dsread_perM, index_t dswrite_perM, index_t load_perM)
|
||||
{
|
||||
|
||||
// Init inst order
|
||||
index_t max_data_inst = dsread_perM > load_perM
|
||||
? (dsread_perM > dswrite_perM ? dsread_perM : dswrite_perM)
|
||||
: (load_perM > dswrite_perM ? load_perM : dswrite_perM);
|
||||
index_t sum_data_inst = dsread_perM + load_perM + dswrite_perM;
|
||||
index_t round_data_inst = ck_tile::integer_divide_ceil(sum_data_inst, mfma_perM_perK);
|
||||
|
||||
constexpr int kOrderCap = NIterPerWarp * 10;
|
||||
index_t inst_order[kOrderCap] = {};
|
||||
index_t index = 0;
|
||||
#pragma unroll
|
||||
// round-robin
|
||||
// Index: 0 1 2 3 4 5 ...
|
||||
// Value: 1 2 3 1 2 3 ...
|
||||
for(int j = 0; j < max_data_inst; j++)
|
||||
{
|
||||
if(dswrite_perM > j)
|
||||
{
|
||||
inst_order[index] = 1;
|
||||
index++;
|
||||
}
|
||||
if(load_perM > j)
|
||||
{
|
||||
inst_order[index] = 2;
|
||||
index++;
|
||||
}
|
||||
if(dsread_perM > j)
|
||||
{
|
||||
inst_order[index] = 3;
|
||||
index++;
|
||||
}
|
||||
}
|
||||
|
||||
// Schedule IGLP
|
||||
#pragma unroll
|
||||
for(int j = 0; j < mfma_perM_perK; j++)
|
||||
{
|
||||
index_t inst_idx = 0;
|
||||
if(j == 0)
|
||||
;
|
||||
else if(j == 1)
|
||||
inst_idx = mfma_perM_perK == 2 ? 1 : mfma_perM_perK - 2;
|
||||
else if(j == 2)
|
||||
inst_idx = mfma_perM_perK - 1;
|
||||
else
|
||||
inst_idx = mfma_perM_perK - j;
|
||||
|
||||
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
|
||||
|
||||
#pragma unroll
|
||||
for(int r = 0; r < round_data_inst; r++)
|
||||
{
|
||||
if(r % 2 == 0)
|
||||
{
|
||||
if(inst_order[inst_idx + r * mfma_perM_perK] == 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
}
|
||||
if(inst_order[inst_idx + r * mfma_perM_perK] == 2)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
if(inst_order[inst_idx + r * mfma_perM_perK] == 3)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 1)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x200, 1, 0); // DS write
|
||||
}
|
||||
if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 2)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
}
|
||||
if(inst_order[(r + 1) * mfma_perM_perK - 1 - inst_idx] == 3)
|
||||
{
|
||||
__builtin_amdgcn_sched_group_barrier(0x100, 1, 0); // DS read
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto HotLoopScheduler()
|
||||
{
|
||||
// Keypoint of pipeline optimize is workload balance in time
|
||||
// instruction schedule example(128X256X256, 1X4, 16X16X128):
|
||||
// Iter MNK MFMA ds_read ds_write A_load b_load
|
||||
// -1 M6N0: 57 - 8 - -
|
||||
// -1 M6N1: 58 1 - - -
|
||||
// -1 M6N2: 59 - - 7 -
|
||||
// -1 M6N3: 60 2 - - -
|
||||
// -1 M7N0: 61 - - - -
|
||||
// -1 M7N1: 62 3 - - -
|
||||
// -1 M7N2: 63 - - 8 -
|
||||
// -1 M7N3: 64 4 - - -
|
||||
// 0 M0N0K0: 1 - - - 1
|
||||
// 0 M0N1: 2 5 - - -
|
||||
// 0 M0N2: 3 - - - 2
|
||||
// 0 M0N3: 4 6 - - -
|
||||
// 0 M1N0: 5 - - - 3
|
||||
// 0 M1N1: 6 7 - - -
|
||||
// 0 M1N2: 7 - - - 4
|
||||
// 0 M1N3: 8 8 - - -
|
||||
// 0 M2N0: 9 - - - 5
|
||||
// 0 M2N1: 10 9 - - -
|
||||
// 0 M2N2: 11 - - - 6
|
||||
// 0 M2N3: 12 10 - - -
|
||||
// 0 M3N0: 13 - 1 - 7
|
||||
// 0 M3N1: 14 11 - - -
|
||||
// 0 M3N2: 15 - - - 8
|
||||
// 0 M3N3: 16 12 - - -
|
||||
// 0 M4N0: 17 - 2 - -
|
||||
// 0 M4N1: 18 13 - - -
|
||||
// 0 M4N2: 19 - - 1 -
|
||||
// 0 M4N3: 20 14 - - -
|
||||
// 0 M5N0: 21 - 3 - -
|
||||
// 0 M5N1: 22 15 - - -
|
||||
// 0 M5N2: 23 - - 2 -
|
||||
// 0 M5N3: 24 16 - - -
|
||||
// 0 M6N0: 25 - 4 - -
|
||||
// 0 M6N1: 26 17 - - -
|
||||
// 0 M6N2: 27 - - 3 -
|
||||
// 0 M6N3: 28 18 - - -
|
||||
// 0 M7N0: 29 - - - -
|
||||
// 0 M7N1: 30 19 - - -
|
||||
// 0 M7N2: 31 - - 4 -
|
||||
// 0 M7N3: 32 20 - - -
|
||||
// 0 M0N0K1: 33 - - - 9
|
||||
// 0 M0N1: 34 21 - - -
|
||||
// 0 M0N2: 35 - - - 10
|
||||
// 0 M0N3: 36 22 - - -
|
||||
// 0 M1N0: 37 - - - 11
|
||||
// 0 M1N1: 38 23 - - -
|
||||
// 0 M1N2: 39 - - - 12
|
||||
// 0 M1N3: 40 24 - - -
|
||||
// 0 M2N0: 41 - - - 13
|
||||
// 0 M2N1: 42 25 - - -
|
||||
// 0 M2N2: 43 - - - 14
|
||||
// 0 M2N3: 44 26 - - -
|
||||
// 0 M3N0: 45 - 5 - 15
|
||||
// 0 M3N1: 46 27 - - -
|
||||
// 0 M3N2: 47 - - - 16
|
||||
// 0 M3N3: 48 28 - - -
|
||||
// 0 M4N0: 49 - 6 - -
|
||||
// 0 M4N1: 50 29 - - -
|
||||
// 0 M4N2: 51 - - 5 -
|
||||
// 0 M4N3: 52 30 - - -
|
||||
// 0 M5N0: 53 - 7 - -
|
||||
// 0 M5N1: 54 31 - - -
|
||||
// 0 M5N2: 55 - - 6 -
|
||||
// 0 M5N3: 56 32 - - -
|
||||
// 0 M6N0: 57 - 8 - -
|
||||
// 0 M6N1: 58 1 - - -
|
||||
// 0 M6N2: 59 - - 7 -
|
||||
// 0 M6N3: 60 2 - - -
|
||||
// 0 M7N0: 61 - - - -
|
||||
// 0 M7N1: 62 3 - - -
|
||||
// 0 M7N2: 63 - - 8 -
|
||||
// 0 M7N3: 64 4 - - -
|
||||
|
||||
#pragma unroll
|
||||
for(int kIter = 0; kIter < KIterPerWarp; kIter++)
|
||||
{
|
||||
#pragma unroll
|
||||
for(int mIter = 0; mIter < MIterPerWarp; mIter++)
|
||||
{
|
||||
index_t dsread_perM = 0;
|
||||
index_t dswrite_perM = 0;
|
||||
index_t load_perM = 0;
|
||||
|
||||
// Calculate ds_read number per M
|
||||
dsread_perM = dsread_per_wg;
|
||||
|
||||
// Calculate ds_write number per M
|
||||
if(mIter == 0)
|
||||
{
|
||||
dswrite_perM =
|
||||
(dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep) > 0
|
||||
? dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep
|
||||
: 0;
|
||||
}
|
||||
else if(mIter >= MIterPerWarp - DsWritePreIssue + 1)
|
||||
{
|
||||
dswrite_perM = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
dswrite_perM = (dswrite_num_perK -
|
||||
(MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0
|
||||
? dswrite_rep
|
||||
: 0;
|
||||
}
|
||||
// Add ds write when ds write data > needed
|
||||
if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter))
|
||||
{
|
||||
if(mIter == MIterPerWarp - 1 - dswrite_mIter)
|
||||
dswrite_perM = 1;
|
||||
}
|
||||
|
||||
// Calculate buffer_load number per M
|
||||
if(mIter < HalfMIter)
|
||||
{
|
||||
load_perM =
|
||||
((Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0 ? Aload_rep
|
||||
: 0) +
|
||||
((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep
|
||||
: 0);
|
||||
}
|
||||
else
|
||||
{
|
||||
load_perM = (Aload_num_perK - (MIterPerWarp - 1 - mIter) * Aload_rep) > 0
|
||||
? Aload_rep
|
||||
: 0;
|
||||
}
|
||||
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
|
||||
}
|
||||
}
|
||||
// Add Aload when Aload data > needed
|
||||
if(Aload_num_perK == 0)
|
||||
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto Last2ndHotLoopScheduler()
|
||||
{
|
||||
#pragma unroll
|
||||
for(int kIter = 0; kIter < KIterPerWarp; kIter++)
|
||||
{
|
||||
#pragma unroll
|
||||
for(int mIter = 0; mIter < MIterPerWarp; mIter++)
|
||||
{
|
||||
index_t dsread_perM = 0;
|
||||
index_t dswrite_perM = 0;
|
||||
index_t load_perM = 0;
|
||||
|
||||
// Calculate ds_read number per M
|
||||
dsread_perM = dsread_per_wg;
|
||||
|
||||
// Calculate ds_write number per M
|
||||
if(mIter == 0)
|
||||
{
|
||||
dswrite_perM =
|
||||
(dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep) > 0
|
||||
? dswrite_num_perK - (MIterPerWarp - DsWritePreIssue) * dswrite_rep
|
||||
: 0;
|
||||
}
|
||||
else if(mIter >= MIterPerWarp - DsWritePreIssue + 1)
|
||||
{
|
||||
dswrite_perM = 0;
|
||||
}
|
||||
else
|
||||
{
|
||||
dswrite_perM = (dswrite_num_perK -
|
||||
(MIterPerWarp - DsWritePreIssue - mIter) * dswrite_rep) > 0
|
||||
? dswrite_rep
|
||||
: 0;
|
||||
}
|
||||
// Add ds write when ds write data > needed
|
||||
if(dswrite_num_perK == 0 && kIter == (KIterPerWarp - 1 - dswrite_kIter))
|
||||
{
|
||||
if(mIter == MIterPerWarp - 1 - dswrite_mIter)
|
||||
dswrite_perM = 1;
|
||||
}
|
||||
|
||||
// Calculate buffer_load number per M
|
||||
if(mIter < HalfMIter)
|
||||
{
|
||||
load_perM =
|
||||
((Bload_num_perK - (HalfMIter - 1 - mIter) * Bload_rep) > 0 ? Bload_rep
|
||||
: 0);
|
||||
}
|
||||
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
|
||||
}
|
||||
}
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto LastHotLoopScheduler()
|
||||
{
|
||||
#pragma unroll
|
||||
for(int kIter = 0; kIter < KIterPerWarp; kIter++)
|
||||
{
|
||||
#pragma unroll
|
||||
for(int mIter = 0; mIter < MIterPerWarp; mIter++)
|
||||
{
|
||||
index_t dsread_perM = 0;
|
||||
index_t dswrite_perM = 0;
|
||||
index_t load_perM = 0;
|
||||
|
||||
// Calculate ds_read number per M
|
||||
if((kIter * MIterPerWarp + mIter) < (KIterPerWarp * MIterPerWarp - m_preload))
|
||||
dsread_perM = dsread_per_wg;
|
||||
|
||||
SchedulerPerM(dsread_perM, dswrite_perM, load_perM);
|
||||
}
|
||||
}
|
||||
// __builtin_amdgcn_sched_barrier(0);
|
||||
}
|
||||
|
||||
struct PipelineImpl : public PipelineImplBase
|
||||
{
|
||||
using Base = PipelineImplBase;
|
||||
|
||||
template <bool HasHotLoop,
|
||||
TailNumber TailNum,
|
||||
typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
|
||||
bool>* = nullptr,
|
||||
index_t UnaryOpSize_ = 8>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
[[maybe_unused]] const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
static_assert(
|
||||
std::is_same_v<ADataType, remove_cvref_t<typename ADramBlockWindowTmp::DataType>>,
|
||||
"wrong!");
|
||||
|
||||
static_assert(kMPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}],
|
||||
"wrong!");
|
||||
static_assert(kKPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
|
||||
"wrong!");
|
||||
|
||||
// A tile in LDS
|
||||
constexpr index_t smem_size = PipelinePolicy::template GetSmemSize<Problem>();
|
||||
|
||||
constexpr auto a_lds_block_desc =
|
||||
PipelinePolicy::template MakeALdsBlockDescriptor<Problem>();
|
||||
|
||||
auto a_lds_blocks = generate_tuple(
|
||||
[&](auto i) {
|
||||
ADataType* p_a_lds = static_cast<ADataType*>(
|
||||
static_cast<void*>(static_cast<char*>(p_smem) + smem_size * i.value));
|
||||
return make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
|
||||
},
|
||||
number<2>{});
|
||||
|
||||
constexpr auto a_lds_load_tile_distr = make_static_tile_distribution(
|
||||
BlockWeightPreshuffle::MakeABlockDistributionEncode());
|
||||
auto&& windows_result =
|
||||
Base::GetAWindows(a_dram_block_window_tmp, a_lds_blocks, a_lds_load_tile_distr);
|
||||
auto&& a_copy_dram_window = windows_result.template get<0>();
|
||||
auto&& a_lds_windows = windows_result.template get<1>();
|
||||
auto a_copy_lds_windows = generate_tuple(
|
||||
[&](auto i) -> decltype(auto) { return a_lds_windows[i].template at<0>(); },
|
||||
number<2>{});
|
||||
// Block GEMM
|
||||
auto block_weight_preshuffle = BlockWeightPreshuffle();
|
||||
// Acc register tile
|
||||
auto c_block_tile = block_weight_preshuffle.MakeCBlockTile();
|
||||
|
||||
auto a_load_windows = generate_tuple(
|
||||
[&](auto i) -> decltype(auto) {
|
||||
return block_weight_preshuffle.MakeALoadWindows(a_copy_lds_windows[i]);
|
||||
},
|
||||
number<2>{});
|
||||
|
||||
// B flat DRAM window for load
|
||||
auto b_flat_distribution =
|
||||
PipelinePolicy::template MakeBFlatDramTileDistribution<Problem>();
|
||||
auto b_flat_dram_window = // tile_window_with_static_distribution
|
||||
make_tile_window(b_flat_dram_block_window_tmp
|
||||
.get_bottom_tensor_view(), // from kernel gemm_pad_views
|
||||
make_tuple(number<flatNPerWarp * NIterPerWarp>{},
|
||||
number<flatKPerWarp * KIterPerWarp>{}),
|
||||
b_flat_dram_block_window_tmp.get_window_origin(),
|
||||
b_flat_distribution);
|
||||
|
||||
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
|
||||
using BDramTileWindowStep = typename BFlatBlockWindowTmp::BottomTensorIndex;
|
||||
constexpr ADramTileWindowStep a_dram_tile_window_step = make_array(0, kKPerBlock);
|
||||
constexpr BDramTileWindowStep b_dram_tile_window_step = make_array(0, kflatKPerBlock);
|
||||
|
||||
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
|
||||
using ABlockTile =
|
||||
decltype(make_static_distributed_tensor<ADataType>(ABlockTileDistr{}));
|
||||
|
||||
using BTypeToUse =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
using BBlockTile =
|
||||
decltype(make_static_distributed_tensor<BTypeToUse>(b_flat_distribution));
|
||||
|
||||
ABlockTile a_global_tile;
|
||||
BBlockTile b_global_tile[2];
|
||||
|
||||
// // Prefetch A0
|
||||
Base::GlobalPrefetch(a_global_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
Base::GlobalPrefetch(b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step);
|
||||
|
||||
// Prefill A0
|
||||
Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile);
|
||||
|
||||
// Prefetch A1
|
||||
Base::GlobalPrefetch(a_global_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
|
||||
// initialize C
|
||||
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// preload A00,A10 from lds
|
||||
block_weight_preshuffle.LocalPrefetch(a_load_windows[I0]);
|
||||
|
||||
__builtin_amdgcn_sched_barrier(0);
|
||||
// MAIN LOOP
|
||||
if constexpr(HasHotLoop)
|
||||
{
|
||||
index_t i_global_read = amd_wave_read_first_lane(2);
|
||||
do
|
||||
{
|
||||
{
|
||||
Base::GlobalPrefetch(
|
||||
b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step);
|
||||
Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile);
|
||||
Base::GlobalPrefetch(
|
||||
a_global_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
block_weight_preshuffle(c_block_tile,
|
||||
a_load_windows[I0],
|
||||
b_global_tile[0],
|
||||
b_flat_distribution);
|
||||
|
||||
block_weight_preshuffle.LocalPrefetch(a_load_windows[I1]);
|
||||
HotLoopScheduler();
|
||||
}
|
||||
{
|
||||
Base::GlobalPrefetch(
|
||||
b_global_tile[0], b_flat_dram_window, b_dram_tile_window_step);
|
||||
Base::LocalPrefill(a_copy_lds_windows[I0], a_global_tile);
|
||||
Base::GlobalPrefetch(
|
||||
a_global_tile, a_copy_dram_window, a_dram_tile_window_step);
|
||||
block_weight_preshuffle(c_block_tile,
|
||||
a_load_windows[I1],
|
||||
b_global_tile[1],
|
||||
b_flat_distribution);
|
||||
|
||||
block_weight_preshuffle.LocalPrefetch(a_load_windows[I0]);
|
||||
HotLoopScheduler();
|
||||
}
|
||||
i_global_read += 2;
|
||||
} while(i_global_read < num_loop);
|
||||
}
|
||||
|
||||
// tail
|
||||
if constexpr(TailNum == TailNumber::Even)
|
||||
{
|
||||
{
|
||||
Base::GlobalPrefetch(
|
||||
b_global_tile[1], b_flat_dram_window, b_dram_tile_window_step);
|
||||
Base::LocalPrefill(a_copy_lds_windows[I1], a_global_tile);
|
||||
block_weight_preshuffle(
|
||||
c_block_tile, a_load_windows[I0], b_global_tile[0], b_flat_distribution);
|
||||
block_sync_lds();
|
||||
block_weight_preshuffle.LocalPrefetch(a_load_windows[I1]);
|
||||
Last2ndHotLoopScheduler();
|
||||
}
|
||||
{
|
||||
block_weight_preshuffle(
|
||||
c_block_tile, a_load_windows[I1], b_global_tile[1], b_flat_distribution);
|
||||
LastHotLoopScheduler();
|
||||
}
|
||||
}
|
||||
else if constexpr(TailNum == TailNumber::Odd)
|
||||
{
|
||||
block_weight_preshuffle(
|
||||
c_block_tile, a_load_windows[I0], b_global_tile[0], b_flat_distribution);
|
||||
LastHotLoopScheduler();
|
||||
}
|
||||
|
||||
return c_block_tile;
|
||||
}
|
||||
};
|
||||
|
||||
// called from universal gemm kernel
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename AElementFunction,
|
||||
typename BElementFunction,
|
||||
typename std::enable_if_t<is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
is_detected<is_tuple, BFlatBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
[[maybe_unused]] const AElementFunction& a_element_func,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
[[maybe_unused]] const BElementFunction& b_element_func,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
const auto has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
return PipelineImpl{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp[number<0>{}],
|
||||
a_element_func,
|
||||
b_flat_dram_block_window_tmp[number<0>{}],
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
// called from general gemm kernel
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
void* p_smem) const
|
||||
{
|
||||
const auto has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto tail_number = Base::GetBlockLoopTailNum(num_loop);
|
||||
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
constexpr auto PassThrough = [](const ADataType& a) { return a; };
|
||||
return PipelineImpl{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
b_flat_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
|
||||
// called from grouped gemm kernel
|
||||
template <typename ADramBlockWindowTmp,
|
||||
typename BFlatBlockWindowTmp,
|
||||
typename std::enable_if_t<!is_detected<is_tuple, ADramBlockWindowTmp>::value &&
|
||||
!is_detected<is_tuple, BFlatBlockWindowTmp>::value,
|
||||
bool>* = nullptr>
|
||||
CK_TILE_DEVICE auto operator()(const ADramBlockWindowTmp& a_dram_block_window_tmp,
|
||||
const BFlatBlockWindowTmp& b_flat_dram_block_window_tmp,
|
||||
index_t num_loop,
|
||||
TailNumber tail_number,
|
||||
void* __restrict__ p_smem) const
|
||||
{
|
||||
const auto has_hot_loop = Base::BlockHasHotloop(num_loop);
|
||||
const auto RunPipeline = [&](auto hot_loop_, auto tail_num_) {
|
||||
constexpr auto PassThrough = [](const auto& x) { return x; };
|
||||
return PipelineImpl{}.template operator()<hot_loop_.value, tail_num_.value>(
|
||||
a_dram_block_window_tmp,
|
||||
PassThrough,
|
||||
b_flat_dram_block_window_tmp,
|
||||
num_loop,
|
||||
p_smem);
|
||||
};
|
||||
return Base::TailHandler(RunPipeline, has_hot_loop, tail_number);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user