Fix default epilogue (#2358)

* [ck-tile] fix default epilogue in gemm universal

* argument validation needs vector size D

* operator() needs to specify dram windows

* copy/paste from cshuffle epilogue

* clang-format

* mark unused argument

---------

Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
Max Podkorytov
2025-06-17 17:30:21 -07:00
committed by GitHub
parent 0eb8974502
commit cd606f72c1

View File

@@ -71,9 +71,11 @@ struct Default2DEpilogue
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
// how do we fix this ?
template <typename ODramWindowTmp, typename OAccTile>
CK_TILE_DEVICE auto
operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr)
template <typename ODramWindowTmp, typename OAccTile, typename DsDramWindows>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
const OAccTile& o_acc_tile,
const DsDramWindows& /* unused */,
void* = nullptr)
{
// TODO: this is ugly
@@ -114,6 +116,8 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using DsDataType = ck_tile::tuple<>;
using DsLayout = ck_tile::tuple<>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
static constexpr index_t kMPerXdl = Problem::kMPerXdl;
static constexpr index_t kNPerXdl = Problem::kNPerXdl;
@@ -181,6 +185,8 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
static_assert(false, "Unsupported CLayout!");
}
}
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeD() { return 1; }
};
} // namespace ck_tile