From 4e5e1d469ca44d81d322f46f12bb66670d6e0c3f Mon Sep 17 00:00:00 2001 From: jakpiase Date: Wed, 7 May 2025 19:46:53 +0200 Subject: [PATCH] fix for default epilogue (#2167) [ROCm/composable_kernel commit: cb07ad84d5b8a6a796dff34c5d990476b6693b16] --- .../ops/epilogue/default_2d_epilogue.hpp | 34 +++++++++++-------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index 1d6a99eb4b..a2915f5c8f 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -15,14 +15,16 @@ template + bool UseRawStore_ = true, + memory_operation_enum MemoryOperation_ = memory_operation_enum::set> struct Default2DEpilogueProblem { - using AccDataType = remove_cvref_t; - using ODataType = remove_cvref_t; - static constexpr bool kPadM = kPadM_; - static constexpr bool kPadN = kPadN_; - static constexpr bool UseRawStore = UseRawStore_; + using AccDataType = remove_cvref_t; + using ODataType = remove_cvref_t; + static constexpr bool kPadM = kPadM_; + static constexpr bool kPadN = kPadN_; + static constexpr bool UseRawStore = UseRawStore_; + static constexpr memory_operation_enum MemoryOperation = MemoryOperation_; }; template -struct DefaultGemm2DEpilogueProblem - : public Default2DEpilogueProblem + bool UseRawStore_ = true, + memory_operation_enum MemoryOperation_ = memory_operation_enum::set> +struct DefaultGemm2DEpilogueProblem : public Default2DEpilogueProblem { using ADataType = remove_cvref_t; using BDataType = remove_cvref_t; @@ -58,14 +65,13 @@ struct Default2DEpilogue static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadN = Problem::kPadN; static constexpr bool UseRawStore = Problem::UseRawStore; + static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation; CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; } // TODO: this function assume store out vector size is the same as OAccTile last dimension size // how do we fix this ? - template + template CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr) { @@ -73,7 +79,7 @@ struct Default2DEpilogue // TODO: this is ugly if constexpr(UseRawStore && (kPadM || kPadN)) { - if constexpr(out_memory_data_op == memory_operation_enum::set) + if constexpr(MemoryOperation == memory_operation_enum::set) { store_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tile)); } @@ -85,7 +91,7 @@ struct Default2DEpilogue } else { - if constexpr(out_memory_data_op == memory_operation_enum::set) + if constexpr(MemoryOperation == memory_operation_enum::set) { store_tile(o_dram_window_tmp, cast_tile(o_acc_tile)); }