mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-29 11:16:59 +00:00
[GEMM] Refactor GetStaticLdsSize and remove GetSmemSize
This commit is contained in:
@@ -30,9 +30,6 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
|
||||
{
|
||||
#if defined(ENABLE_INSTRUCTION_SCH)
|
||||
return Policy::template GetSmemSize<Problem>();
|
||||
#else
|
||||
return integer_divide_ceil(
|
||||
sizeof(ADataType) *
|
||||
Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
|
||||
@@ -40,7 +37,6 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
16 +
|
||||
sizeof(BDataType) *
|
||||
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
|
||||
#endif
|
||||
}
|
||||
|
||||
#if defined(ENABLE_INSTRUCTION_SCH)
|
||||
@@ -67,7 +63,6 @@ struct BlockGemmPipelineAGmemBGmemCReg
|
||||
|
||||
// Forwards to Policy::GetVectorSizeA<Problem>() and returns its result as index_t.
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
|
||||
// Forwards to Policy::GetVectorSizeB<Problem>() and returns its result as index_t.
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
|
||||
// Forwards to Policy::GetVectorSizeC<Problem>() and returns its result as index_t.
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
|
||||
|
||||
// Forwards to Policy::GetSmemPackA<Problem>() and returns its result as index_t.
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
|
||||
// Forwards to Policy::GetSmemPackB<Problem>() and returns its result as index_t.
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }
|
||||
|
||||
@@ -336,23 +336,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
|
||||
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, KPerBlock>();
|
||||
}
|
||||
|
||||
// Returns the length of the last Y dimension of the C warp distribution,
// i.e. c_warp_y_lengths.get(number<NDimY - 1>{}).
//
// The warp GEMM type is taken from the block GEMM selected by
// GetBlockGemm<Problem>(). A static_assert requires the returned length to
// equal WG::WarpGemmAttribute::Impl::kCM1PerLane, so instantiation fails at
// compile time if the two disagree.
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetBlockGemm<Problem>())>;
|
||||
using WG = typename BlockGemm::WarpGemm;
|
||||
using CWarpDstr = typename WG::CWarpDstr;
|
||||
|
||||
// In this case each thread has multiple consecutive elements in
|
||||
// N dimension, however consecutive threads' elements have stride.
|
||||
constexpr index_t NDimY = CWarpDstr::NDimY;
|
||||
constexpr auto c_warp_y_lengths =
|
||||
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
|
||||
// Guard: the last Y length must match the warp GEMM's kCM1PerLane.
static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
|
||||
c_warp_y_lengths.get(number<NDimY - 1>{}));
|
||||
return c_warp_y_lengths.get(number<NDimY - 1>{});
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
|
||||
{
|
||||
@@ -362,44 +345,21 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
|
||||
// Returns the pack size reported for A.
//
// As written, the first return statement yields BlockGemm::Traits::KPack of
// the block GEMM selected by GetBlockGemm<Problem>().
//
// NOTE(review): a second body (kKPack = 8, followed by a commented-out copy
// of the first body) comes after that return and is unreachable as written.
// This looks like the removed and added sides of a diff shown together
// without +/- markers -- confirm which version is current before relying on
// either value.
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetBlockGemm<Problem>())>;
|
||||
constexpr index_t KPack = BlockGemm::Traits::KPack;
|
||||
return KPack;
|
||||
constexpr index_t kKPack = 8;
|
||||
return kKPack;
|
||||
// using BlockGemm = remove_cvref_t<decltype(GetBlockGemm<Problem>())>;
|
||||
// constexpr index_t KPack = BlockGemm::Traits::KPack;
|
||||
// return KPack;
|
||||
}
|
||||
|
||||
// Returns BlockGemm::Traits::KPack, where BlockGemm is the (cv/ref-stripped)
// type produced by GetBlockGemm<Problem>().
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
|
||||
{
|
||||
using BlockGemm = remove_cvref_t<decltype(GetBlockGemm<Problem>())>;
|
||||
constexpr index_t KPack = BlockGemm::Traits::KPack;
|
||||
return KPack;
|
||||
}
|
||||
|
||||
// Computes a size for the A tile from the descriptor returned by
// MakeALdsBlockDescriptor<Problem>():
//   integer_least_multiple(sizeof(Problem::ADataType) * element_space_size, 16)
// NOTE(review): presumably this rounds the byte count up to a multiple of 16;
// confirm against the definition of integer_least_multiple.
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
|
||||
{
|
||||
constexpr auto a_lds_desc = MakeALdsBlockDescriptor<Problem>();
|
||||
constexpr index_t smem_size_a = integer_least_multiple(
|
||||
sizeof(typename Problem::ADataType) * a_lds_desc.get_element_space_size(), 16);
|
||||
return smem_size_a;
|
||||
}
|
||||
|
||||
// Computes a size for the B tile from the descriptor returned by
// MakeBLdsBlockDescriptor<Problem>():
//   integer_least_multiple(sizeof(Problem::BDataType) * element_space_size, 16)
// NOTE(review): presumably this rounds the byte count up to a multiple of 16;
// confirm against the definition of integer_least_multiple.
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
|
||||
{
|
||||
constexpr auto b_lds_desc = MakeBLdsBlockDescriptor<Problem>();
|
||||
constexpr index_t smem_size_b = integer_least_multiple(
|
||||
sizeof(typename Problem::BDataType) * b_lds_desc.get_element_space_size(), 16);
|
||||
return smem_size_b;
|
||||
}
|
||||
|
||||
// Returns GetSmemSizeA<Problem>() + GetSmemSizeB<Problem>().
//
// NOTE(review): the statements after the first return (kKPack = 8 and the
// commented-out KPack lines) are unreachable as written and do not relate to
// a size computation. They look like lines from the other side of a diff
// shown without +/- markers -- confirm against the actual file.
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
|
||||
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
|
||||
|
||||
return smem_size_a + smem_size_b;
|
||||
constexpr index_t kKPack = 8;
|
||||
return kKPack;
|
||||
// using BlockGemm = remove_cvref_t<decltype(GetBlockGemm<Problem>())>;
|
||||
// constexpr index_t KPack = BlockGemm::Traits::KPack;
|
||||
// return KPack;
|
||||
}
|
||||
|
||||
template <typename Problem>
|
||||
|
||||
@@ -61,28 +61,6 @@ struct TileGemmShape
|
||||
#endif
|
||||
|
||||
#if defined(ENABLE_INSTRUCTION_SCH)
|
||||
// Compile-time traits bundle for a tile GEMM.
//
// Template parameters:
//   kPadM_, kPadN_, kPadK_        - re-exported as kPadM / kPadN / kPadK.
//   ALayout_, BLayout_, CLayout_  - re-exported as ALayout / BLayout / CLayout.
//
// This variant has no TransposeC parameter: TransposeC is fixed to false,
// and _VectorSize is fixed to 16.
template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_,
|
||||
typename ALayout_,
|
||||
typename BLayout_,
|
||||
typename CLayout_>
|
||||
struct TileGemmTraits
|
||||
{
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool kPadK = kPadK_;
|
||||
|
||||
// TODO this can't be hardcoded here! Should be in policy!
|
||||
static constexpr int _VectorSize = 16;
|
||||
|
||||
using ALayout = ALayout_;
|
||||
using BLayout = BLayout_;
|
||||
using CLayout = CLayout_;
|
||||
|
||||
// Not configurable in this variant; always false.
static constexpr bool TransposeC = false;
|
||||
};
|
||||
|
||||
template <bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool kPadK_,
|
||||
@@ -91,7 +69,7 @@ template <bool kPadM_,
|
||||
typename BLayout_,
|
||||
typename CLayout_,
|
||||
bool TransposeC_ = false>
|
||||
struct TileGemmUniversalTraits
|
||||
struct TileGemmTraits
|
||||
{
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
@@ -284,22 +262,21 @@ struct Gemm
|
||||
PermuteA,
|
||||
PermuteB>;
|
||||
|
||||
using GemmUniversalTraits =
|
||||
TileGemmUniversalTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
DoubleSmemBuffer,
|
||||
/* ALayout */ tensor_layout::gemm::RowMajor,
|
||||
/* BLayout */ tensor_layout::gemm::ColumnMajor,
|
||||
/* CLayout */ tensor_layout::gemm::RowMajor,
|
||||
TransposeC>;
|
||||
using GemmTraits = TileGemmTraits<kPadM,
|
||||
kPadN,
|
||||
kPadK,
|
||||
DoubleSmemBuffer,
|
||||
/* ALayout */ tensor_layout::gemm::RowMajor,
|
||||
/* BLayout */ tensor_layout::gemm::ColumnMajor,
|
||||
/* CLayout */ tensor_layout::gemm::RowMajor,
|
||||
TransposeC>;
|
||||
|
||||
using BlockGemmPipelineProblem_ =
|
||||
BlockGemmPipelineProblem<ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
GemmShape,
|
||||
GemmUniversalTraits,
|
||||
GemmTraits,
|
||||
GemmPipelineScheduler::Intrawave,
|
||||
/* Has hot loop */ true,
|
||||
TailNumber::Full>;
|
||||
|
||||
Reference in New Issue
Block a user