Files
composable_kernel/include/ck_tile/ops/epilogue/default_2d_epilogue.hpp
Bartłomiej Kocot af66494880 [CK TILE] GEMM and Batched GEMM SplitK support (#1724)
* [CK TILE] Add split K support in GEMM

* Updates

* Fixes

* rebase

* fix

* Fix

* fixes

* support for batched gemm
2024-12-28 14:40:17 +01:00

75 lines
2.5 KiB
C++

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// This epilogue simply stores out an M*N matrix in row-major order.
// Compile-time parameter bundle for Default2DEpilogue. It re-exports the accumulator and
// output element types together with three boolean switches so that the epilogue can read
// them as Problem::AccDataType, Problem::ODataType, Problem::kPadM, Problem::kPadN and
// Problem::UseRawStore.
template <typename AccDataType_,
          typename ODataType_,
          bool kPadM_,
          bool kPadN_,
          bool UseRawStore_ = true>
struct Default2DEpilogueProblem
{
    // Boolean switches, forwarded unchanged from the template arguments.
    // UseRawStore defaults to true when the fifth argument is omitted.
    static constexpr bool kPadM{kPadM_};
    static constexpr bool kPadN{kPadN_};
    static constexpr bool UseRawStore{UseRawStore_};

    // Element types, normalized through remove_cvref_t.
    using ODataType   = remove_cvref_t<ODataType_>;
    using AccDataType = remove_cvref_t<AccDataType_>;
};
// Epilogue that converts an accumulator tile to ODataType and writes it through the window
// passed as o_dram_window_tmp. It reports a shared-memory requirement of 0 and a
// non-transposed output. Policy_ is accepted but not referenced inside this struct.
template <typename Problem_, typename Policy_ = void>
struct Default2DEpilogue
{
    using Problem     = remove_cvref_t<Problem_>;
    using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
    using ODataType   = remove_cvref_t<typename Problem::ODataType>;

    // Switches mirrored from the problem description.
    static constexpr bool kPadM       = Problem::kPadM;
    static constexpr bool kPadN       = Problem::kPadN;
    static constexpr bool UseRawStore = Problem::UseRawStore;

    // This epilogue asks for no shared memory.
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }

    // The output is written without transposition.
    CK_TILE_HOST_DEVICE static constexpr bool IsOutputTransposed() { return false; }

    // Casts o_acc_tile to ODataType and writes it through o_dram_window_tmp.
    //
    // out_memory_data_op == memory_operation_enum::set selects the store_tile* functions;
    // every other value selects the update_tile* functions. The *_raw variants, followed by
    // buffer_store_fence(), are used only when UseRawStore is enabled and at least one of
    // kPadM / kPadN is set.
    //
    // TODO: this function assumes the store-out vector size equals the size of the last
    // dimension of OAccTile; how do we fix this?
    template <typename ODramWindowTmp,
              typename OAccTile,
              memory_operation_enum out_memory_data_op = memory_operation_enum::set>
    CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile)
    {
        // TODO: this compile-time dispatch is ugly
        constexpr bool write_raw = UseRawStore && (kPadM || kPadN);
        constexpr bool overwrite = out_memory_data_op == memory_operation_enum::set;

        if constexpr(overwrite)
        {
            if constexpr(write_raw)
            {
                store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
            }
            else
            {
                store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
            }
        }
        else
        {
            if constexpr(write_raw)
            {
                update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
            }
            else
            {
                update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
            }
        }

        // Only the raw path is followed by a fence.
        if constexpr(write_raw)
        {
            buffer_store_fence();
        }
    }
};
} // namespace ck_tile