mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 04:31:25 +00:00
fix for default epilogue (#2167)
This commit is contained in:
@@ -15,14 +15,16 @@ template <typename AccDataType_,
|
||||
typename ODataType_,
|
||||
bool kPadM_,
|
||||
bool kPadN_,
|
||||
bool UseRawStore_ = true>
|
||||
bool UseRawStore_ = true,
|
||||
memory_operation_enum MemoryOperation_ = memory_operation_enum::set>
|
||||
struct Default2DEpilogueProblem
|
||||
{
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool UseRawStore = UseRawStore_;
|
||||
using AccDataType = remove_cvref_t<AccDataType_>;
|
||||
using ODataType = remove_cvref_t<ODataType_>;
|
||||
static constexpr bool kPadM = kPadM_;
|
||||
static constexpr bool kPadN = kPadN_;
|
||||
static constexpr bool UseRawStore = UseRawStore_;
|
||||
static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
|
||||
};
|
||||
|
||||
template <typename ADataType_,
|
||||
@@ -36,9 +38,14 @@ template <typename ADataType_,
|
||||
index_t kNPerXdl_,
|
||||
index_t kKPerXdl_,
|
||||
bool isCTransposed_,
|
||||
bool UseRawStore_ = true>
|
||||
struct DefaultGemm2DEpilogueProblem
|
||||
: public Default2DEpilogueProblem<AccDataType_, ODataType_, kPadM_, kPadN_, UseRawStore_>
|
||||
bool UseRawStore_ = true,
|
||||
memory_operation_enum MemoryOperation_ = memory_operation_enum::set>
|
||||
struct DefaultGemm2DEpilogueProblem : public Default2DEpilogueProblem<AccDataType_,
|
||||
ODataType_,
|
||||
kPadM_,
|
||||
kPadN_,
|
||||
UseRawStore_,
|
||||
MemoryOperation_>
|
||||
{
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
@@ -58,14 +65,13 @@ struct Default2DEpilogue
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool UseRawStore = Problem::UseRawStore;
|
||||
static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
|
||||
|
||||
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
|
||||
// how do we fix this ?
|
||||
template <typename ODramWindowTmp,
|
||||
typename OAccTile,
|
||||
memory_operation_enum out_memory_data_op = memory_operation_enum::set>
|
||||
template <typename ODramWindowTmp, typename OAccTile>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr)
|
||||
{
|
||||
@@ -73,7 +79,7 @@ struct Default2DEpilogue
|
||||
// TODO: this is ugly
|
||||
if constexpr(UseRawStore && (kPadM || kPadN))
|
||||
{
|
||||
if constexpr(out_memory_data_op == memory_operation_enum::set)
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
}
|
||||
@@ -85,7 +91,7 @@ struct Default2DEpilogue
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(out_memory_data_op == memory_operation_enum::set)
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user