Reland fix default epilogue (#2367)

* Revert "Revert "Fix default epilogue  (#2358)" (#2364)"

This reverts commit 64a2fda713.

* add operator() with old signature
This commit is contained in:
Max Podkorytov
2025-06-19 10:39:30 -07:00
committed by GitHub
parent c8b247c55c
commit 11eb9f1c77

View File

@@ -75,7 +75,6 @@ struct Default2DEpilogue
CK_TILE_DEVICE auto
operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr)
{
// TODO: this is ugly
if constexpr(UseRawStore && (kPadM || kPadN))
{
@@ -101,6 +100,15 @@ struct Default2DEpilogue
}
}
}
template <typename ODramWindowTmp, typename OAccTile, typename DsDramWindows>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
const OAccTile& o_acc_tile,
const DsDramWindows& /* unused */,
void* = nullptr)
{
return operator()<ODramWindowTmp, OAccTile>(o_dram_window_tmp, o_acc_tile);
}
};
template <typename Problem_, typename Policy_ = void>
@@ -114,6 +122,8 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using DsDataType = ck_tile::tuple<>;
using DsLayout = ck_tile::tuple<>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
static constexpr index_t kMPerXdl = Problem::kMPerXdl;
static constexpr index_t kNPerXdl = Problem::kNPerXdl;
@@ -181,6 +191,8 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
static_assert(false, "Unsupported CLayout!");
}
}
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeD() { return 1; }
};
} // namespace ck_tile