From b29e3830a6b7676cac3610ace84fcc8589c993a7 Mon Sep 17 00:00:00 2001 From: Max Podkorytov <4273004+tenpercent@users.noreply.github.com> Date: Tue, 17 Jun 2025 17:30:21 -0700 Subject: [PATCH] Fix default epilogue (#2358) * [ck-tile] fix default epilogue in gemm universal * argument validation needs vector size D * operator() needs to specify dram windows * copy/paste from cshuffle epilogue * clang-format * mark unused argument --------- Co-authored-by: Thomas Ning [ROCm/composable_kernel commit: cd606f72c1fb3a99d596ad0f79521b46152764cb] --- include/ck_tile/ops/epilogue/default_2d_epilogue.hpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp index ab3c0df88d..623433c1dc 100644 --- a/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp +++ b/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp @@ -71,9 +71,11 @@ struct Default2DEpilogue // TODO: this function assume store out vector size is the same as OAccTile last dimension size // how do we fix this ? - template - CK_TILE_DEVICE auto - operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr) + template + CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, + const OAccTile& o_acc_tile, + const DsDramWindows& /* unused */, + void* = nullptr) { // TODO: this is ugly @@ -114,6 +116,8 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue // Used for weight-only quantization kernel, B would be dequantized to the same data type as A using BTypeToUse = std::conditional_t, ADataType, BDataType>; + using DsDataType = ck_tile::tuple<>; + using DsLayout = ck_tile::tuple<>; using CLayout = remove_cvref_t; static constexpr index_t kMPerXdl = Problem::kMPerXdl; static constexpr index_t kNPerXdl = Problem::kNPerXdl; @@ -181,6 +185,8 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue static_assert(false, "Unsupported CLayout!"); } } + + CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeD() { return 1; } }; } // namespace ck_tile