// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"

namespace ck_tile {

// NOTE(review): in this copy of the file the template parameter lists and most template
// argument lists (the text inside `<...>`) are missing, e.g. `template struct ...` and
// `remove_cvref_t;`. The comments below describe only what the remaining tokens show;
// anything that depends on the missing arguments is marked as an assumption.

// Problem descriptor for Default2DEpilogue.
// This epilogue just stores out an M*N matrix, row major.
// It carries the accumulator/output data types, the M/N padding flags, whether raw
// stores are used, and the memory operation applied to the output. A plain 2D problem
// has no extra D tensors (NumDTensor == 0).
template struct Default2DEpilogueProblem
{
    using AccDataType = remove_cvref_t;
    using ODataType   = remove_cvref_t;
    // Padding flags for M and N. When either is set and UseRawStore is set,
    // Default2DEpilogue takes its raw store/update path.
    static constexpr bool kPadM       = kPadM_;
    static constexpr bool kPadN       = kPadN_;
    static constexpr bool UseRawStore = UseRawStore_;
    // memory_operation_enum::set makes Default2DEpilogue store the tile; any other value
    // makes it update the tile instead.
    static constexpr memory_operation_enum MemoryOperation = MemoryOperation_;
    // No D tensors are fused in the plain 2D epilogue.
    static constexpr index_t NumDTensor = 0;
};

// GEMM-flavoured epilogue problem. It extends Default2DEpilogueProblem with:
//  - the A/B data types and the C layout;
//  - optional D tensors (data types, layouts, and the CDElementwise functor used to
//    combine them with C);
//  - block tile sizes, warp-GEMM (XDL) tile sizes, and the C-transposition flag.
template struct DefaultGemm2DEpilogueProblem : public Default2DEpilogueProblem
{
    using AsDataType    = remove_cvref_t;
    using BsDataType    = remove_cvref_t;
    using CLayout       = remove_cvref_t;
    using DsDataType    = remove_cvref_t;
    using CDElementwise = remove_cvref_t;
    using DsLayout      = remove_cvref_t;
    // Output tile extents handled per block (also the window size used for the fused
    // D-tensor result tile in Default2DEpilogue).
    static constexpr index_t kMPerBlock = kM_;
    static constexpr index_t kNPerBlock = kN_;
    // Warp-level GEMM tile sizes per XDL instruction.
    static constexpr index_t kMPerXdl = kMPerXdl_;
    static constexpr index_t kNPerXdl = kNPerXdl_;
    static constexpr index_t kKPerXdl = kKPerXdl_;
    // NOTE(review): used as a boolean (`if constexpr(isCTransposed)` in
    // DefaultGemm2DEpilogue) but declared as index_t; `bool` would state the intent.
    static constexpr index_t isCTransposed = isCTransposed_;
    // Hides the base class's NumDTensor (0) with the actual number of D tensors.
    static constexpr index_t NumDTensor = DsDataType::size();
    static_assert(NumDTensor == DsLayout::size(),
                  "The size of DsDataType and DsLayout should be the same");
};

// Epilogue that writes a block's accumulator tile to the output DRAM window. When D
// tensors are present, it first fuses them through Problem::CDElementwise.
template struct Default2DEpilogue
{
    using Problem     = remove_cvref_t;
    using AccDataType = remove_cvref_t;
    using ODataType   = remove_cvref_t;
    static constexpr bool kPadM       = Problem::kPadM;
    static constexpr bool kPadN       = Problem::kPadN;
    static constexpr bool UseRawStore = Problem::UseRawStore;
    static constexpr memory_operation_enum MemoryOperation = Problem::MemoryOperation;

    // This epilogue requests no shared memory.
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }

    // TODO: this function assumes the store-out vector size is the same as the OAccTile
    // last-dimension size. How do we fix this?
    //
    // Writes o_acc_tile (or, when D tensors are fused, the CDElementwise result) to
    // o_dram_window_tmp.
    //   o_dram_window_tmp : destination DRAM tile window.
    //   o_acc_tile        : accumulator tile. Its tile distribution is also used to
    //                       load the D tiles.
    //   ds_dram_windows   : either a tuple of D-tensor DRAM windows, or a partition
    //                       index forwarded to the store/update calls. It is treated as
    //                       a partition index when is_partition_index is true; the
    //                       is_convertible_v arguments are not visible here, but the
    //                       target is presumably an index type (TODO confirm).
    //   void*             : unnamed and unused here. Presumably kept for interface
    //                       compatibility with epilogues that take a shared-memory
    //                       pointer (TODO confirm).
    template CK_TILE_DEVICE auto operator()(ODramWindowTmp& o_dram_window_tmp,
                                            const OAccTile& o_acc_tile,
                                            const DsDramWindows& ds_dram_windows,
                                            void* = nullptr) const
    {
        constexpr bool is_partition_index = std::is_convertible_v;

        // Casts o_tile and writes it to o_dram_window_tmp. The cast target is not
        // visible here; presumably it is ODataType.
        //  - Raw path (UseRawStore && (kPadM || kPadN)): store_tile_raw or
        //    update_tile_raw, followed by buffer_store_fence().
        //  - Otherwise: store_tile or update_tile.
        // MemoryOperation == set selects store; any other value selects update.
        // NOTE(review): the update_tile_raw call does not forward the partition index
        // even when is_partition_index is true, unlike the other three paths.
        const auto storeOrUpdateTile = [&](const auto& o_tile) {
            // TODO: this is ugly
            if constexpr(UseRawStore && (kPadM || kPadN))
            {
                if constexpr(MemoryOperation == memory_operation_enum::set)
                {
                    if constexpr(is_partition_index)
                    {
                        store_tile_raw(o_dram_window_tmp,
                                       cast_tile(o_tile),
                                       /*partition_index=*/ds_dram_windows);
                    }
                    else
                    {
                        store_tile_raw(o_dram_window_tmp, cast_tile(o_tile));
                    }
                }
                else
                {
                    update_tile_raw(o_dram_window_tmp, cast_tile(o_tile));
                }
                buffer_store_fence();
            }
            else
            {
                if constexpr(MemoryOperation == memory_operation_enum::set)
                {
                    if constexpr(is_partition_index)
                    {
                        store_tile(o_dram_window_tmp,
                                   cast_tile(o_tile),
                                   /*partition_index=*/ds_dram_windows);
                    }
                    else
                    {
                        store_tile(o_dram_window_tmp, cast_tile(o_tile));
                    }
                }
                else
                {
                    if constexpr(is_partition_index)
                    {
                        update_tile(o_dram_window_tmp,
                                    cast_tile(o_tile),
                                    /*partition_index=*/ds_dram_windows);
                    }
                    else
                    {
                        update_tile(o_dram_window_tmp, cast_tile(o_tile));
                    }
                }
            }
        };

        // Fused path. It is taken when three conditions hold:
        //  - the (not visible) `!std::is_same_v<...>` check passes;
        //  - ds_dram_windows holds D windows rather than a partition index;
        //  - there is at least one D tensor.
        if constexpr(!std::is_same_v && !is_partition_index && Problem::NumDTensor >= 1)
        {
            // Result tile type: whatever load_tile returns for a kMPerBlock x kNPerBlock
            // window over D0's tensor view at D0's window origin, using the
            // accumulator's tile distribution.
            using elementwise_result_t = decltype(load_tile(
                make_tile_window(ds_dram_windows[number<0>{}].get_bottom_tensor_view(),
                                 make_tuple(Problem::kMPerBlock, Problem::kNPerBlock),
                                 ds_dram_windows[number<0>{}].get_window_origin(),
                                 o_acc_tile.get_tile_distribution())));

            elementwise_result_t elementwise_result;

            // Load the D tiles with the accumulator's tile distribution. The count
            // argument (number<...>) is not visible here; presumably it is NumDTensor.
            const auto d_tensor_tuple = generate_tuple(
                [&](auto idx) {
                    const auto d_tile_window =
                        make_tile_window(ds_dram_windows[idx], o_acc_tile.get_tile_distribution());
                    return load_tile(d_tile_window);
                },
                number{});

            // Tuple of references, in order: (elementwise_result, o_acc_tile, D tiles...).
            const auto c_d_tuple = concat_tuple_of_reference(
                tie(elementwise_result, o_acc_tile),
                generate_tie([&](auto idx) -> const auto& { return d_tensor_tuple[idx]; },
                             number{}));

            // Apply CDElementwise across the tuple. Presumably this writes
            // elementwise_result (the first element) from the remaining ones; TODO
            // confirm against tile_elementwise_inout_unpack.
            tile_elementwise_inout_unpack(typename Problem::CDElementwise{}, c_d_tuple);

            storeOrUpdateTile(elementwise_result);
        }
        else
        {
            storeOrUpdateTile(o_acc_tile);
        }
    }
};

// Default2DEpilogue for GEMM problems. It additionally exposes the A/B/D type
// information, the C layout, and the warp-GEMM configuration, and it derives the
// vector sizes used for C and D accesses.
template struct DefaultGemm2DEpilogue : public Default2DEpilogue
{
    using Problem     = remove_cvref_t;
    using AsDataType  = remove_cvref_t;
    using BsDataType  = remove_cvref_t;
    using AccDataType = remove_cvref_t;
    using ODataType   = remove_cvref_t;
    // Whether the A / B data types are given as tuples (detected via is_detected).
    static constexpr bool ADataTypeIsTuple = is_detected::value;
    static constexpr bool BDataTypeIsTuple = is_detected::value;
    // A/B data types in tuple form. The std::conditional_t arguments are not visible
    // here; presumably a single type is wrapped in a tuple.
    using AsDataTypeTuple = std::conditional_t, remove_cvref_t>>;
    using BsDataTypeTuple = std::conditional_t, remove_cvref_t>>;
    // Representative A / B element types taken from the tuples (presumably element 0;
    // the selector is not visible here).
    using ADataType = remove_cvref_t{}, AsDataTypeTuple>>;
    using BDataType = remove_cvref_t{}, BsDataTypeTuple>>;
    // Used for weight-only quantization kernel: B would be dequantized to the same data
    // type as A.
    using BTypeToUse = std::conditional_t, ADataType, BDataType>;
    using DsDataType    = remove_cvref_t;
    using DsLayout      = remove_cvref_t;
    using CDElementwise = remove_cvref_t;
    using CLayout       = remove_cvref_t;
    static constexpr index_t kMPerXdl      = Problem::kMPerXdl;
    static constexpr index_t kNPerXdl      = Problem::kNPerXdl;
    static constexpr index_t kKPerXdl      = Problem::kKPerXdl;
    static constexpr index_t isCTransposed = Problem::isCTransposed;
    // Warp GEMM chosen by WarpGemmDispatcher (its arguments are not visible here). Its
    // C-warp distribution drives GetVectorSizeC.
    using WG        = WarpGemmDispatcher;
    using CWarpDstr = typename WG::CWarpDstr;

    // Vector size for C accesses, derived from the warp GEMM's C layout:
    //   N contiguous in CLayout:
    //     isCTransposed -> one Y-dimension length of CWarpDstr (asserted == kCM1PerLane;
    //                      the index is not visible, presumably NDimY - 1)
    //     otherwise     -> (kCNLane * kBNBlock) / WG::kN
    //   M contiguous in CLayout:
    //     isCTransposed -> (kCNLane * kAMBlock) / WG::kN
    //     otherwise     -> same Y-dimension length as above (asserted == kCM1PerLane)
    //   any other CLayout -> compile-time error.
    CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
    {
        // N is the contiguous dimension.
        if constexpr(std::is_same_v)
        {
            if constexpr(isCTransposed)
            {
                // In this case each thread has multiple consecutive elements in the
                // N dimension; however, consecutive threads' elements have a stride.
                constexpr index_t NDimY = CWarpDstr::NDimY;
                constexpr auto c_warp_y_lengths =
                    CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
                static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
                              c_warp_y_lengths.get(number{}));
                return c_warp_y_lengths.get(number{});
            }
            else
            {
                // In this case each thread has just a single item in the N dimension.
                return (WG::WarpGemmAttribute::Impl::kCNLane *
                        WG::WarpGemmAttribute::Impl::kBNBlock) /
                       WG::kN;
            }
        }
        // M is the contiguous dimension.
        else if constexpr(std::is_same_v)
        {
            if constexpr(isCTransposed)
            {
                // In this case each thread has just a single item in the M dimension.
                return (WG::WarpGemmAttribute::Impl::kCNLane *
                        WG::WarpGemmAttribute::Impl::kAMBlock) /
                       WG::kN;
            }
            else
            {
                // In this case each thread has multiple consecutive elements in the
                // M dimension; however, consecutive threads' elements have a stride.
                constexpr index_t NDimY = CWarpDstr::NDimY;
                constexpr auto c_warp_y_lengths =
                    CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
                static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
                              c_warp_y_lengths.get(number{}));
                return c_warp_y_lengths.get(number{});
            }
        }
        else
        {
            // NOTE(review): a non-dependent static_assert(false) in a discarded
            // if-constexpr branch is only guaranteed to be accepted after P2593
            // (C++23, applied as a DR by some compilers). Older compilers may reject it
            // even for supported layouts. A dependent-false condition is the portable
            // form.
            static_assert(false, "Unsupported CLayout!");
        }
    }

    // Vector size for the D tensor at `index`. It returns C's vector size for every
    // index, i.e. it assumes each D tensor is accessed like C (same layout).
    // TODO: confirm this holds for every DsLayout entry.
    template CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeD([[maybe_unused]] number index)
    {
        return GetVectorSizeC();
    }
};

} // namespace ck_tile