[CK_TILE] Update flatmm related kernels (#3022)

---------

Co-authored-by: Ding, Yi <yi.ding@amd.com>
Co-authored-by: felix <felix.li@amd.com>
This commit is contained in:
lalala-sh
2025-10-22 22:36:11 +08:00
committed by GitHub
parent cbd1279ae6
commit 211d64e18a
39 changed files with 11183 additions and 739 deletions

View File

@@ -310,4 +310,147 @@ struct UniversalGemmPipelineProblem
}
};
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
typename BlockGemmShape_,
typename Traits_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full,
typename ComputeDataType_ = ADataType_>
struct FlatmmPipelineProblem
{
using Traits = remove_cvref_t<Traits_>;
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
using ALayout = remove_cvref_t<typename Traits::AsLayout>;
using BLayout = remove_cvref_t<typename Traits::BsLayout>;
using CLayout = remove_cvref_t<typename Traits::CLayout>;
static constexpr bool TransposeC = Traits::TransposeC;
static constexpr index_t NumWaveGroups = Traits::NumWaveGroups;
static constexpr bool UseStructuredSparsity = Traits::UseStructuredSparsity;
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
static constexpr bool kPadM = Traits::kPadM;
static constexpr bool kPadN = Traits::kPadN;
static constexpr bool kPadK = Traits::kPadK;
static constexpr bool DoubleSmemBuffer = Traits::DoubleSmemBuffer;
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
static constexpr auto HasHotLoop = HasHotLoop_;
static constexpr auto TailNum = TailNum_;
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
// clang-format off
return concat('_', "gemm_problem",
concat('x', VectorLoadSize, kBlockSize),
concat('x', kPadM, kPadN, kPadK),
Scheduler);
// clang-format on
}
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
{
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<ADataType>>::PackedSize;
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
constexpr index_t pixels_per_thread =
BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(ADataType)
? pixels_per_thread
: PackedSize * VectorLoadSize / sizeof(ADataType);
}
else
{
return VectorLoadSize / sizeof(ADataType);
}
}
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB()
{
constexpr index_t PackedSize =
ck_tile::numeric_traits<remove_cvref_t<BDataType>>::PackedSize;
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t pixels_per_thread =
BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
return pixels_per_thread < PackedSize * VectorLoadSize / sizeof(BDataType)
? pixels_per_thread
: PackedSize * VectorLoadSize / sizeof(BDataType);
}
else
{
return PackedSize * VectorLoadSize / sizeof(BDataType);
}
}
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentC()
{
if constexpr(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N2 = std::min(BlockGemmShape::kN / N1, get_warp_size());
constexpr index_t M0 = get_warp_size() / N2;
constexpr index_t M1 = BlockGemmShape::kM / M0;
return std::min(M1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
}
else
{
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M2 = std::min(BlockGemmShape::kM / M1, get_warp_size());
constexpr index_t N0 = get_warp_size() / M2;
constexpr index_t N1 = BlockGemmShape::kN / N0;
return std::min(N1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
}
}
static constexpr index_t VectorSizeA = []() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return kPadK ? 1 : GetAlignmentA();
}
else
{
return kPadM ? 1 : GetAlignmentA();
}
}();
static constexpr index_t VectorSizeB = []() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return kPadN ? 1 : GetAlignmentB();
}
else
{
return kPadK ? 1 : GetAlignmentB();
}
}();
static constexpr index_t VectorSizeC = []() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return kPadN ? 1 : GetAlignmentC();
}
else
{
return kPadM ? 1 : GetAlignmentC();
}
}();
};
} // namespace ck_tile