diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index ab3c0df88d..ff41ac0d61 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -75,7 +75,6 @@ struct Default2DEpilogue CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr) { - // TODO: this is ugly if constexpr(UseRawStore && (kPadM || kPadN)) { @@ -101,6 +100,15 @@ struct Default2DEpilogue } } } + + template + CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, + const OAccTile& o_acc_tile, + const DsDramWindows& /* unused */, + void* = nullptr) + { + return operator()(o_dram_window_tmp, o_acc_tile); + } }; template @@ -114,6 +122,8 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue // Used for weight-only quantization kernel, B would be dequantized to the same data type as A using BTypeToUse = std::conditional_t, ADataType, BDataType>; + using DsDataType = ck_tile::tuple<>; + using DsLayout = ck_tile::tuple<>; using CLayout = remove_cvref_t; static constexpr index_t kMPerXdl = Problem::kMPerXdl; static constexpr index_t kNPerXdl = Problem::kNPerXdl; @@ -181,6 +191,8 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue static_assert(false, "Unsupported CLayout!"); } } + + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeD() { return 1; } }; } // namespace ck_tile