Merge commit '11eb9f1c7711a419cfa0db5346c80edb1eaf7b4a' into develop

This commit is contained in:
assistant-librarian[bot]
2025-06-19 18:13:43 +00:00
parent 6502245d9a
commit 6653706e01

View File

@@ -75,7 +75,6 @@ struct Default2DEpilogue
CK_TILE_DEVICE auto
operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr)
{
// TODO: this is ugly
if constexpr(UseRawStore && (kPadM || kPadN))
{
@@ -101,6 +100,15 @@ struct Default2DEpilogue
}
}
}
template <typename ODramWindowTmp, typename OAccTile, typename DsDramWindows>
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
const OAccTile& o_acc_tile,
const DsDramWindows& /* unused */,
void* = nullptr)
{
return operator()<ODramWindowTmp, OAccTile>(o_dram_window_tmp, o_acc_tile);
}
};
template <typename Problem_, typename Policy_ = void>
@@ -114,6 +122,8 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
using BTypeToUse =
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
using DsDataType = ck_tile::tuple<>;
using DsLayout = ck_tile::tuple<>;
using CLayout = remove_cvref_t<typename Problem::CLayout>;
static constexpr index_t kMPerXdl = Problem::kMPerXdl;
static constexpr index_t kNPerXdl = Problem::kNPerXdl;
@@ -181,6 +191,8 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
static_assert(false, "Unsupported CLayout!");
}
}
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeD() { return 1; }
};
} // namespace ck_tile