fix for default epilogue (#2167)

This commit is contained in:
jakpiase
2025-05-07 19:46:53 +02:00
committed by GitHub
parent 397b9080a2
commit cb07ad84d5

View File

@@ -15,14 +15,16 @@ template <typename AccDataType_,
typename ODataType_,
bool kPadM_,
bool kPadN_,
bool UseRawStore_ = true>
bool UseRawStore_ = true,
memory_operation_enum MemoryOperation_ = memory_operation_enum::set>
struct Default2DEpilogueProblem
{
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool UseRawStore = UseRawStore_;
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool UseRawStore = UseRawStore_;
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
};
template <typename ADataType_,
@@ -36,9 +38,14 @@ template <typename ADataType_,
index_t kNPerXdl_,
index_t kKPerXdl_,
bool isCTransposed_,
bool UseRawStore_ = true>
struct DefaultGemm2DEpilogueProblem
: public Default2DEpilogueProblem<AccDataType_, ODataType_, kPadM_, kPadN_, UseRawStore_>
bool UseRawStore_ = true,
memory_operation_enum MemoryOperation_ = memory_operation_enum::set>
struct DefaultGemm2DEpilogueProblem : public Default2DEpilogueProblem<AccDataType_,
ODataType_,
kPadM_,
kPadN_,
UseRawStore_,
MemoryOperation_>
{
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
@@ -58,14 +65,13 @@ struct Default2DEpilogue
static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool UseRawStore = Problem::UseRawStore;
static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
// TODO: this function assumes the store-out vector size is the same as the size of the last
// dimension of OAccTile. How do we fix this?
template <typename ODramWindowTmp,
typename OAccTile,
memory_operation_enum out_memory_data_op = memory_operation_enum::set>
template <typename ODramWindowTmp, typename OAccTile>
CK_TILE_DEVICE auto
operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr)
{
@@ -73,7 +79,7 @@ struct Default2DEpilogue
// TODO: this is ugly
if constexpr(UseRawStore && (kPadM || kPadN))
{
if constexpr(out_memory_data_op == memory_operation_enum::set)
if constexpr(MemoryOperation == memory_operation_enum::set)
{
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}
@@ -85,7 +91,7 @@ struct Default2DEpilogue
}
else
{
if constexpr(out_memory_data_op == memory_operation_enum::set)
if constexpr(MemoryOperation == memory_operation_enum::set)
{
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
}