[GEMM] Refactor GetStaticLdsSize and remove GetSmemSize

This commit is contained in:
YC Lin
2025-04-10 14:22:22 +00:00
parent 04199bc0aa
commit 6fdf2bd896
3 changed files with 20 additions and 88 deletions

View File

@@ -30,9 +30,6 @@ struct BlockGemmPipelineAGmemBGmemCReg
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetStaticLdsSize()
{
#if defined(ENABLE_INSTRUCTION_SCH)
return Policy::template GetSmemSize<Problem>();
#else
return integer_divide_ceil(
sizeof(ADataType) *
Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
@@ -40,7 +37,6 @@ struct BlockGemmPipelineAGmemBGmemCReg
16 +
sizeof(BDataType) *
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
#endif
}
#if defined(ENABLE_INSTRUCTION_SCH)
@@ -67,7 +63,6 @@ struct BlockGemmPipelineAGmemBGmemCReg
static constexpr index_t GetVectorSizeA() { return Policy::template GetVectorSizeA<Problem>(); }
static constexpr index_t GetVectorSizeB() { return Policy::template GetVectorSizeB<Problem>(); }
static constexpr index_t GetVectorSizeC() { return Policy::template GetVectorSizeC<Problem>(); }
static constexpr index_t GetSmemPackA() { return Policy::template GetSmemPackA<Problem>(); }
static constexpr index_t GetSmemPackB() { return Policy::template GetSmemPackB<Problem>(); }

View File

@@ -336,23 +336,6 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, KPerBlock>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{
using BlockGemm = remove_cvref_t<decltype(GetBlockGemm<Problem>())>;
using WG = typename BlockGemm::WarpGemm;
using CWarpDstr = typename WG::CWarpDstr;
// In this case each thread has multiple consecutive elements in
// N dimension, however consecutive threads' elements have stride.
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
c_warp_y_lengths.get(number<NDimY - 1>{}));
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
{
@@ -362,44 +345,21 @@ struct BlockGemmPipelineAGmemBGmemCRegDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
{
using BlockGemm = remove_cvref_t<decltype(GetBlockGemm<Problem>())>;
constexpr index_t KPack = BlockGemm::Traits::KPack;
return KPack;
constexpr index_t kKPack = 8;
return kKPack;
// using BlockGemm = remove_cvref_t<decltype(GetBlockGemm<Problem>())>;
// constexpr index_t KPack = BlockGemm::Traits::KPack;
// return KPack;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
{
using BlockGemm = remove_cvref_t<decltype(GetBlockGemm<Problem>())>;
constexpr index_t KPack = BlockGemm::Traits::KPack;
return KPack;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
constexpr auto a_lds_desc = MakeALdsBlockDescriptor<Problem>();
constexpr index_t smem_size_a = integer_least_multiple(
sizeof(typename Problem::ADataType) * a_lds_desc.get_element_space_size(), 16);
return smem_size_a;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{
constexpr auto b_lds_desc = MakeBLdsBlockDescriptor<Problem>();
constexpr index_t smem_size_b = integer_least_multiple(
sizeof(typename Problem::BDataType) * b_lds_desc.get_element_space_size(), 16);
return smem_size_b;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
return smem_size_a + smem_size_b;
constexpr index_t kKPack = 8;
return kKPack;
// using BlockGemm = remove_cvref_t<decltype(GetBlockGemm<Problem>())>;
// constexpr index_t KPack = BlockGemm::Traits::KPack;
// return KPack;
}
template <typename Problem>

View File

@@ -61,28 +61,6 @@ struct TileGemmShape
#endif
#if defined(ENABLE_INSTRUCTION_SCH)
template <bool kPadM_,
bool kPadN_,
bool kPadK_,
typename ALayout_,
typename BLayout_,
typename CLayout_>
struct TileGemmTraits
{
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool kPadK = kPadK_;
// TODO this can't be hardcoded here! Should be in policy!
static constexpr int _VectorSize = 16;
using ALayout = ALayout_;
using BLayout = BLayout_;
using CLayout = CLayout_;
static constexpr bool TransposeC = false;
};
template <bool kPadM_,
bool kPadN_,
bool kPadK_,
@@ -91,7 +69,7 @@ template <bool kPadM_,
typename BLayout_,
typename CLayout_,
bool TransposeC_ = false>
struct TileGemmUniversalTraits
struct TileGemmTraits
{
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
@@ -284,22 +262,21 @@ struct Gemm
PermuteA,
PermuteB>;
using GemmUniversalTraits =
TileGemmUniversalTraits<kPadM,
kPadN,
kPadK,
DoubleSmemBuffer,
/* ALayout */ tensor_layout::gemm::RowMajor,
/* BLayout */ tensor_layout::gemm::ColumnMajor,
/* CLayout */ tensor_layout::gemm::RowMajor,
TransposeC>;
using GemmTraits = TileGemmTraits<kPadM,
kPadN,
kPadK,
DoubleSmemBuffer,
/* ALayout */ tensor_layout::gemm::RowMajor,
/* BLayout */ tensor_layout::gemm::ColumnMajor,
/* CLayout */ tensor_layout::gemm::RowMajor,
TransposeC>;
using BlockGemmPipelineProblem_ =
BlockGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
GemmUniversalTraits,
GemmTraits,
GemmPipelineScheduler::Intrawave,
/* Has hot loop */ true,
TailNumber::Full>;