mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
[CK-TILE] Default epilogue, adding support for D (#2629)
* Extend 2d-epilogue, D support * Added tests & update * Remove unused attribute * Extend tests --------- Co-authored-by: Thomas Ning <Thomas.Ning@amd.com>
This commit is contained in:
@@ -29,9 +29,14 @@ struct Default2DEpilogueProblem
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename DsDataType_,
|
||||
typename AccDataType_,
|
||||
typename ODataType_,
|
||||
typename DsLayout_,
|
||||
typename CLayout_,
|
||||
typename CDElementwise_,
|
||||
index_t kM_,
|
||||
index_t kN_,
|
||||
bool kPadM_,
|
||||
bool kPadN_,
|
||||
index_t kMPerXdl_,
|
||||
@@ -50,10 +55,20 @@ struct DefaultGemm2DEpilogueProblem : public Default2DEpilogueProblem<AccDataTyp
|
||||
using ADataType = remove_cvref_t<ADataType_>;
|
||||
using BDataType = remove_cvref_t<BDataType_>;
|
||||
using CLayout = remove_cvref_t<CLayout_>;
|
||||
using DsDataType = remove_cvref_t<DsDataType_>;
|
||||
using DsLayout = remove_cvref_t<DsLayout_>;
|
||||
using CDElementwise = remove_cvref_t<CDElementwise_>;
|
||||
static constexpr index_t kMPerBlock = kM_;
|
||||
static constexpr index_t kNPerBlock = kN_;
|
||||
static constexpr index_t kMPerXdl = kMPerXdl_;
|
||||
static constexpr index_t kNPerXdl = kNPerXdl_;
|
||||
static constexpr index_t kKPerXdl = kKPerXdl_;
|
||||
static constexpr index_t isCTransposed = isCTransposed_;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::size();
|
||||
|
||||
static_assert(NumDTensor == DsLayout::size(),
|
||||
"The size of DsDataType and DsLayout should be the same");
|
||||
};
|
||||
|
||||
template <typename Problem_, typename Policy_ = void>
|
||||
@@ -62,6 +77,7 @@ struct Default2DEpilogue
|
||||
using Problem = remove_cvref_t<Problem_>;
|
||||
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
|
||||
using ODataType = remove_cvref_t<typename Problem::ODataType>;
|
||||
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
|
||||
static constexpr bool kPadM = Problem::kPadM;
|
||||
static constexpr bool kPadN = Problem::kPadN;
|
||||
static constexpr bool UseRawStore = Problem::UseRawStore;
|
||||
@@ -71,43 +87,70 @@ struct Default2DEpilogue
|
||||
|
||||
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
|
||||
// how do we fix this ?
|
||||
template <typename ODramWindowTmp, typename OAccTile>
|
||||
CK_TILE_DEVICE auto
|
||||
operator()(ODramWindowTmp& o_dram_window_tmp, const OAccTile& o_acc_tile, void* = nullptr) const
|
||||
{
|
||||
// TODO: this is ugly
|
||||
if constexpr(UseRawStore && (kPadM || kPadN))
|
||||
{
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
}
|
||||
buffer_store_fence();
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_acc_tile));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename ODramWindowTmp, typename OAccTile, typename DsDramWindows>
|
||||
CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
|
||||
const OAccTile& o_acc_tile,
|
||||
const DsDramWindows& /* unused */,
|
||||
void* = nullptr) const
|
||||
const DsDramWindows& ds_dram_windows,
|
||||
void* = nullptr)
|
||||
{
|
||||
return operator()<ODramWindowTmp, OAccTile>(o_dram_window_tmp, o_acc_tile);
|
||||
const auto storeOrUpdateTile = [&](const auto& o_tile) {
|
||||
// TODO: this is ugly
|
||||
if constexpr(UseRawStore && (kPadM || kPadN))
|
||||
{
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
store_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile_raw(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
|
||||
}
|
||||
buffer_store_fence();
|
||||
}
|
||||
else
|
||||
{
|
||||
if constexpr(MemoryOperation == memory_operation_enum::set)
|
||||
{
|
||||
store_tile(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
|
||||
}
|
||||
else
|
||||
{
|
||||
update_tile(o_dram_window_tmp, cast_tile<ODataType>(o_tile));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if constexpr(Problem::NumDTensor >= 1)
|
||||
{
|
||||
using elementwise_result_t = decltype(load_tile(
|
||||
make_tile_window(ds_dram_windows[number<0>{}].get_bottom_tensor_view(),
|
||||
make_tuple(Problem::kMPerBlock, Problem::kNPerBlock),
|
||||
ds_dram_windows[number<0>{}].get_window_origin(),
|
||||
o_acc_tile.get_tile_distribution())));
|
||||
|
||||
elementwise_result_t elementwise_result;
|
||||
|
||||
const auto d_tensor_tuple = generate_tuple(
|
||||
[&](auto idx) {
|
||||
const auto d_tile_window =
|
||||
make_tile_window(ds_dram_windows[idx], o_acc_tile.get_tile_distribution());
|
||||
return load_tile(d_tile_window);
|
||||
},
|
||||
number<Problem::NumDTensor>{});
|
||||
|
||||
const auto c_d_tuple = concat_tuple_of_reference(
|
||||
tie(elementwise_result, o_acc_tile),
|
||||
generate_tie([&](auto idx) -> const auto& { return d_tensor_tuple[idx]; },
|
||||
number<Problem::NumDTensor>{}));
|
||||
|
||||
tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_d_tuple);
|
||||
|
||||
storeOrUpdateTile(elementwise_result);
|
||||
}
|
||||
else
|
||||
{
|
||||
storeOrUpdateTile(o_acc_tile);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -122,8 +165,9 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
|
||||
// Used for weight-only quantization kernel, B would be dequantized to the same data type as A
|
||||
using BTypeToUse =
|
||||
std::conditional_t<std::is_same_v<BDataType, pk_int4_t>, ADataType, BDataType>;
|
||||
using DsDataType = ck_tile::tuple<>;
|
||||
using DsLayout = ck_tile::tuple<>;
|
||||
using DsDataType = remove_cvref_t<typename Problem::DsDataType>;
|
||||
using DsLayout = remove_cvref_t<typename Problem::DsLayout>;
|
||||
using CDElementwise = remove_cvref_t<typename Problem::CDElementwise>;
|
||||
using CLayout = remove_cvref_t<typename Problem::CLayout>;
|
||||
static constexpr index_t kMPerXdl = Problem::kMPerXdl;
|
||||
static constexpr index_t kNPerXdl = Problem::kNPerXdl;
|
||||
@@ -192,7 +236,11 @@ struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
|
||||
}
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeD() { return 1; }
|
||||
template <index_t I>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeD([[maybe_unused]] number<I> index)
|
||||
{
|
||||
return GetVectorSizeC();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
Reference in New Issue
Block a user