// SPDX-License-Identifier: MIT // Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. #pragma once #include "ck_tile/core.hpp" namespace ck_tile { // this epilogue just store out a M*N matrix, row major template struct Default2DEpilogueProblem { using AccDataType = remove_cvref_t; using ODataType = remove_cvref_t; static constexpr bool kPadM = kPadM_; static constexpr bool kPadN = kPadN_; static constexpr bool UseRawStore = UseRawStore_; }; template struct Default2DEpilogue { using Problem = remove_cvref_t; using AccDataType = remove_cvref_t; using ODataType = remove_cvref_t; static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadN = Problem::kPadN; static constexpr bool UseRawStore = Problem::UseRawStore; CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; } // TODO: this function assume store out vector size is the same as OAccTile last dimension size // how do we fix this ? template CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile) { // TODO: this is ugly if constexpr(UseRawStore && (kPadM || kPadN)) { store_tile_raw(o_dram_window_tmp, cast_tile(o_acc_tile)); buffer_store_fence(); } else { store_tile(o_dram_window_tmp, cast_tile(o_acc_tile)); } } }; } // namespace ck_tile