[CK-Tile] move out memory operation from cshuffle epilogue class (#3359)

* initial poc

* factor out common parts in operator()

* cv4

* rest of the universal gemm pipelines

* fix test

* remove boilerplate from tile engine

* fix example

* fix example

* format

* fix tests build for gemm

* remove base pipeline codegen from gemm instance builder

* unify v3 logic with the rest of universal gemm pipelines

* fix build for multi abd test

* fix test gemm multi d

* fix build for weight preshuffle

* fix grouped gemm test

* fix grouped gemm multi d test

* fix grouped gemm preshuffle

* fix grouped gemm example except for quant

* fix gemm preshuffle

* fix splitk 2 stage example

* fix batched gemm example

* fix multid example

* fix multiabd example

* fix batched gemm test

* fixup

* fix examples build

* fix grouped gemm test build

* fix smoke builder

* hacky poc

* fix tile engine

* kill the lambda

* maybe fix test build

* more fixes

* clang-format

* save temp

* clang-format

* mostly fix examples

* clang-format

* remove dead code

* more cleanup

* fix fmha bwd build (default epilogue set/add appears to be broken)

* fix default epilogue tests but not correctness

* clang-format

* fix bquant

* clang-format

* cleanup dead code

* rearrange make windows for readability

* restore changes to IsSupportedArgument

* fix smoke-builder

* clang-format

* fixup rename class

* build fixes

* clang-format

* fix builder

* fixup

* remove set from builder tests

* fix test

* clang-format

* re-refactor the kernels

* clang-format

* fix header license

* remove memory operation from conv bwd test

* clang-format

* clang-format example,include

* clang-format test

* build fixes

* clang-format

* solve compilation error

* fix the CI

* solve compilation error

* clang format

* solve merge conflict

* solve merge conflict

* solve the gfx11 error

* solve test error

* moar build fixes

* remove AtomicAddRequiresKBatchGreaterThanOne test since the property is removed from the kernel scope

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
Max Podkorytov
2026-01-04 03:28:14 -08:00
committed by GitHub
parent ec23be0b9d
commit e339101e9c
68 changed files with 4198 additions and 4298 deletions

View File

@@ -30,7 +30,6 @@ template <typename AsDataType_,
index_t NPerXdl_,
index_t KPerXdl_,
bool isCTransposed_,
memory_operation_enum MemoryOperation_,
index_t kNumWaveGroups_ = 1,
bool FixedVectorSize_ = false,
index_t VectorSizeC_ = 1,
@@ -39,31 +38,30 @@ template <typename AsDataType_,
bool DoubleSmemBuffer_ = false>
struct CShuffleEpilogueProblem
{
using AsDataType = remove_cvref_t<AsDataType_>;
using BsDataType = remove_cvref_t<BsDataType_>;
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using DsDataType = remove_cvref_t<DsDataType_>;
using DsLayout = remove_cvref_t<DsLayout_>;
using ELayout = remove_cvref_t<ELayout_>;
using CDElementwise = remove_cvref_t<CDElementwise_>;
static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size();
static constexpr index_t kMPerBlock = kM_;
static constexpr index_t kNPerBlock = kN_;
static constexpr index_t MWave = MWave_;
static constexpr index_t NWave = NWave_;
static constexpr index_t MPerXdl = MPerXdl_;
static constexpr index_t NPerXdl = NPerXdl_;
static constexpr index_t KPerXdl = KPerXdl_;
static constexpr index_t isCTransposed = isCTransposed_;
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
static constexpr bool FixedVectorSize = FixedVectorSize_;
static constexpr index_t VectorSizeC = VectorSizeC_;
static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
static constexpr index_t NumDTensor = DsDataType::size();
using AsDataType = remove_cvref_t<AsDataType_>;
using BsDataType = remove_cvref_t<BsDataType_>;
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using DsDataType = remove_cvref_t<DsDataType_>;
using DsLayout = remove_cvref_t<DsLayout_>;
using ELayout = remove_cvref_t<ELayout_>;
using CDElementwise = remove_cvref_t<CDElementwise_>;
static constexpr index_t kBlockSize = MWave_ * NWave_ * get_warp_size();
static constexpr index_t kMPerBlock = kM_;
static constexpr index_t kNPerBlock = kN_;
static constexpr index_t MWave = MWave_;
static constexpr index_t NWave = NWave_;
static constexpr index_t MPerXdl = MPerXdl_;
static constexpr index_t NPerXdl = NPerXdl_;
static constexpr index_t KPerXdl = KPerXdl_;
static constexpr index_t isCTransposed = isCTransposed_;
static constexpr bool FixedVectorSize = FixedVectorSize_;
static constexpr index_t VectorSizeC = VectorSizeC_;
static constexpr index_t BlockedXDLN_PerWarp = BlockedXDLN_PerWarp_;
static constexpr bool DoubleSmemBuffer = DoubleSmemBuffer_;
static constexpr bool TiledMMAPermuteN = TiledMMAPermuteN_;
static constexpr index_t kNumWaveGroups = kNumWaveGroups_;
static constexpr index_t NumDTensor = DsDataType::size();
static_assert(NumDTensor == DsLayout::size(),
"The size of DsDataType and DsLayout should be the same");
@@ -105,28 +103,27 @@ struct CShuffleEpilogue
ADataType,
BDataType>;
using ELayout = remove_cvref_t<typename Problem::ELayout>;
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = Problem::kMPerBlock;
static constexpr index_t kNPerBlock = Problem::kNPerBlock;
static constexpr index_t MWave = Problem::MWave;
static constexpr index_t NWave = Problem::NWave;
static constexpr index_t MPerXdl = Problem::MPerXdl;
static constexpr index_t NPerXdl = Problem::NPerXdl;
static constexpr index_t KPerXdl = Problem::KPerXdl;
static constexpr index_t isCTransposed = Problem::isCTransposed;
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr index_t MPerIteration = MPerXdl * MWave;
static constexpr index_t NPerIteration = NPerXdl * NWave;
static constexpr index_t NumDTensor = Problem::NumDTensor;
static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);
using ELayout = remove_cvref_t<typename Problem::ELayout>;
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kMPerBlock = Problem::kMPerBlock;
static constexpr index_t kNPerBlock = Problem::kNPerBlock;
static constexpr index_t MWave = Problem::MWave;
static constexpr index_t NWave = Problem::NWave;
static constexpr index_t MPerXdl = Problem::MPerXdl;
static constexpr index_t NPerXdl = Problem::NPerXdl;
static constexpr index_t KPerXdl = Problem::KPerXdl;
static constexpr index_t isCTransposed = Problem::isCTransposed;
static constexpr bool FixedVectorSize = Problem::FixedVectorSize;
static constexpr bool TiledMMAPermuteN = Problem::TiledMMAPermuteN;
static constexpr index_t BlockedXDLN_PerWarp = Problem::BlockedXDLN_PerWarp;
static constexpr bool DoubleSmemBuffer = Problem::DoubleSmemBuffer;
static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr index_t MPerIteration = MPerXdl * MWave;
static constexpr index_t NPerIteration = NPerXdl * NWave;
static constexpr index_t NumDTensor = Problem::NumDTensor;
static constexpr index_t MRepeat = kMPerBlock / (MPerXdl * MWave);
static constexpr index_t NRepeat = kNPerBlock / (NPerXdl * NWave);
CDElementwise elfunc_;
@@ -142,8 +139,7 @@ struct CShuffleEpilogue
concat('x', MWave, NWave),
concat('x', MPerXdl, NPerXdl, KPerXdl),
VectorSizeC,
isCTransposed ? "CTransposed" : "CNotTransposed",
mem_op_string<MemoryOperation>());
isCTransposed ? "CTransposed" : "CNotTransposed");
// clang-format on
}
@@ -445,7 +441,8 @@ struct CShuffleEpilogue
CK_TILE_DEVICE void store_to_dram(OutDramWindow& out_dram_window,
const COutTensor& c_out_tensor)
{
if constexpr(MemoryOperation == memory_operation_enum::set)
if constexpr(decltype(out_dram_window.get_bottom_tensor_view())::DstInMemOp ==
memory_operation_enum::set)
{
store_tile(out_dram_window, c_out_tensor);
}
@@ -617,7 +614,8 @@ struct CShuffleEpilogue
});
// store/update
if constexpr(MemoryOperation == memory_operation_enum::set)
if constexpr(decltype(out_dram_window.get_bottom_tensor_view())::DstInMemOp ==
memory_operation_enum::set)
{
store_tile(out_dram_window, c_out_tensor);
}

View File

@@ -15,17 +15,15 @@ template <typename AccDataType_,
typename ODataType_,
bool kPadM_,
bool kPadN_,
bool UseRawStore_ = true,
memory_operation_enum MemoryOperation_ = memory_operation_enum::set>
bool UseRawStore_ = true>
struct Default2DEpilogueProblem
{
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool UseRawStore = UseRawStore_;
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
static constexpr index_t NumDTensor = 0;
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool UseRawStore = UseRawStore_;
static constexpr index_t NumDTensor = 0;
};
template <typename AsDataType_,
@@ -44,14 +42,9 @@ template <typename AsDataType_,
index_t kNPerXdl_,
index_t kKPerXdl_,
bool isCTransposed_,
bool UseRawStore_ = true,
memory_operation_enum MemoryOperation_ = memory_operation_enum::set>
struct DefaultGemm2DEpilogueProblem : public Default2DEpilogueProblem<AccDataType_,
ODataType_,
kPadM_,
kPadN_,
UseRawStore_,
MemoryOperation_>
bool UseRawStore_ = true>
struct DefaultGemm2DEpilogueProblem
: public Default2DEpilogueProblem<AccDataType_, ODataType_, kPadM_, kPadN_, UseRawStore_>
{
using AsDataType = remove_cvref_t<AsDataType_>;
using BsDataType = remove_cvref_t<BsDataType_>;
@@ -81,7 +74,6 @@ struct Default2DEpilogue
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool UseRawStore = Problem::UseRawStore;
static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
@@ -102,7 +94,10 @@ struct Default2DEpilogue
// TODO: this is ugly
if constexpr(UseRawStore && (kPadM || kPadN))
{
if constexpr(MemoryOperation == memory_operation_enum::set)
// FIXME?
// if constexpr(decltype(o_dram_window_tmp.get_bottom_tensor_view())::DstInMemOp ==
// memory_operation_enum::set)
if constexpr(true)
{
if constexpr(is_partition_index)
{
@@ -123,7 +118,10 @@ struct Default2DEpilogue
}
else
{
if constexpr(MemoryOperation == memory_operation_enum::set)
// FIXME?
// if constexpr(decltype(o_dram_window_tmp.get_bottom_tensor_view())::DstInMemOp ==
// memory_operation_enum::set)
if constexpr(true)
{
if constexpr(is_partition_index)
{