mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 05:01:25 +00:00
Introduce multiABD api and deprecate multiD (#1035)
* Introduce multiABD api and deprecate multiD * Replace multiD with multiABD * Mark structures as deprecated * Change doxygen deprecated to note to avoid warnings
This commit is contained in:
@@ -0,0 +1,132 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/utility/is_detected.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename T>
|
||||
using is_tuple = decltype(std::declval<T&>().IsTuple());
|
||||
|
||||
/**
|
||||
* \brief Grouped Convolution Forward
|
||||
*
|
||||
* \details
|
||||
* input : input image A[G, N, C, Hi, Wi], A1[G, N, C, Hi, Wi]...
|
||||
* input : weight B[G, K, C, Y, X], B1[G, K, C, Y, X]...
|
||||
* input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ...
|
||||
* output : output image E[G, N, K, Ho, Wo]
|
||||
*
|
||||
* C = a_op(A, A1...) * b_op(B, B1...)
|
||||
* E = cde_op(C, D0, D1, ...)
|
||||
*
|
||||
* \tparam NDimSpatial Number of spatial dimensions.
|
||||
* \tparam ALayout Input layout (also for a1, a2...).
|
||||
* \tparam BLayout Weight layout (also for b1, b2...).
|
||||
* \tparam DsLayout Ds layouts.
|
||||
* \tparam ELayout Output layout.
|
||||
* \tparam ADataType Input data type. Pass tuple if there is multiple A.
|
||||
* \tparam BDataType Weight data type. Pass tuple if there is multiple B.
|
||||
* \tparam DsDataType D data types.
|
||||
* \tparam EDataType Output data type.
|
||||
* \tparam AElementwiseOperation A elementwise operation.
|
||||
* \tparam BElementwiseOperation B elementwise operation.
|
||||
* \tparam CDEElementwiseOperation CDE elementwise operation.
|
||||
* \tparam ComputeType Compute data type (default: ADataType, first if tuple passed).
|
||||
*/
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename ADataType,
|
||||
typename BDataType,
|
||||
typename DsDataType,
|
||||
typename EDataType,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CDEElementwiseOperation,
|
||||
typename ComputeType =
|
||||
decltype(UnpackDataType<is_detected<is_tuple, ADataType>::value,
|
||||
Number<0>,
|
||||
ADataType>())> // ComputeType is InputType by default (first
|
||||
// in tuple for MultiAB), unpack if tuple was
|
||||
// passed
|
||||
struct DeviceGroupedConvFwdMultipleABD : public BaseOperator
|
||||
{
|
||||
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
|
||||
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
|
||||
|
||||
static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
|
||||
static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor");
|
||||
|
||||
// If DataType is tuple, user has to pass std::array with pointers.
|
||||
using APointers =
|
||||
std::conditional_t<isMultiA, std::array<const void*, NumATensor>&, const void*>;
|
||||
using BPointers =
|
||||
std::conditional_t<isMultiB, std::array<const void*, NumBTensor>&, const void*>;
|
||||
|
||||
/**
|
||||
* \brief Make argument pointer for grouped conv fwd.
|
||||
*
|
||||
* \param p_a A pointer to the input (std::array<const void*, NumA> with
|
||||
pointers for multiple A).
|
||||
* \param p_b A pointer to the weight (std::array<const void*, NumA> with
|
||||
pointers for multiple B).
|
||||
* \param p_ds A pointers to the Ds.
|
||||
* \param p_e A pointers to the output.
|
||||
* \param a_g_n_c_wis_lengths Input lengths [G, N, C, Spatial...] (for 3d).
|
||||
* \param a_g_n_c_wis_strides Input strides [G, N, C, Spatial...] (for 3d).
|
||||
* \param b_g_k_c_xs_lengths Weight lengths [G, K, C, Spatial...] (for 3d).
|
||||
* \param b_g_k_c_xs_strides Weight strides [G, K, C, Spatial...] (for 3d).
|
||||
* \param ds_g_n_k_wos_lengths Ds lengths [G, N, K, Spatial...] (for 3d).
|
||||
* \param ds_g_n_k_wos_strides Ds strides [G, N, K, Spatial...] (for 3d).
|
||||
* \param e_g_n_k_wos_lengths Output lengths [G, N, K, Spatial...] (for 3d).
|
||||
* \param e_g_n_k_wos_strides Output strides [G, N, K, Spatial...] (for 3d).
|
||||
* \param conv_filter_strides Convolution filter strides.
|
||||
* \param conv_filter_dilations Convolution filter dilations.
|
||||
* \param input_left_pads Input left paddings.
|
||||
* \param input_right_pads Input right paddings.
|
||||
* \param a_element_op A elementwise operation object.
|
||||
* \param b_element_op B elementwise operation object.
|
||||
* \param cde_element_op CDE elementwise operation object.
|
||||
* \return Pointer to the argument.
|
||||
*/
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
APointers p_a,
|
||||
BPointers p_b,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -3,30 +3,18 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/device_base.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/utility/is_detected.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename T>
|
||||
using is_tuple = decltype(std::declval<T&>().IsTuple());
|
||||
|
||||
/**
|
||||
* \brief Grouped Convolution Forward
|
||||
*
|
||||
* \details
|
||||
* input : input image A[G, N, C, Hi, Wi], A1[G, N, C, Hi, Wi]...
|
||||
* input : weight B[G, K, C, Y, X], B1[G, K, C, Y, X]...
|
||||
* input : D0[G, N, K, Ho, Wo], D1[G, N, K, Ho, Wo], ...
|
||||
* output : output image E[G, N, K, Ho, Wo]
|
||||
*
|
||||
* C = a_op(A, A1...) * b_op(B, B1...)
|
||||
* E = cde_op(C, D0, D1, ...)
|
||||
* \note This structure is deprecated (left for backwards compatibility). Please use
|
||||
* DeviceGroupedConvFwdMultipleABD.
|
||||
*
|
||||
* \tparam NDimSpatial Number of spatial dimensions.
|
||||
* \tparam ALayout Input layout (also for a1, a2...).
|
||||
@@ -60,72 +48,19 @@ template <index_t NDimSpatial,
|
||||
ADataType>())> // ComputeType is InputType by default (first
|
||||
// in tuple for MultiAB), unpack if tuple was
|
||||
// passed
|
||||
struct DeviceGroupedConvFwdMultipleD : public BaseOperator
|
||||
{
|
||||
static constexpr bool isMultiA = is_detected<is_tuple, ADataType>::value;
|
||||
static constexpr bool isMultiB = is_detected<is_tuple, BDataType>::value;
|
||||
|
||||
static constexpr index_t NumATensor = GetNumABTensors<isMultiA, ADataType>();
|
||||
static constexpr index_t NumBTensor = GetNumABTensors<isMultiB, BDataType>();
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
static_assert(NumDTensor == DsLayout::Size(), "wrong! Inconsistent NumDTensor");
|
||||
|
||||
// If DataType is tuple, user has to pass std::array with pointers.
|
||||
using APointers =
|
||||
std::conditional_t<isMultiA, std::array<const void*, NumATensor>&, const void*>;
|
||||
using BPointers =
|
||||
std::conditional_t<isMultiB, std::array<const void*, NumBTensor>&, const void*>;
|
||||
|
||||
/**
|
||||
* \brief Make argument pointer for grouped conv fwd.
|
||||
*
|
||||
* \param p_a A pointer to the input (std::array<const void*, NumA> with
|
||||
pointers for multiple A).
|
||||
* \param p_b A pointer to the weight (std::array<const void*, NumA> with
|
||||
pointers for multiple B).
|
||||
* \param p_ds A pointers to the Ds.
|
||||
* \param p_e A pointers to the output.
|
||||
* \param a_g_n_c_wis_lengths Input lengths [G, N, C, Spatial...] (for 3d).
|
||||
* \param a_g_n_c_wis_strides Input strides [G, N, C, Spatial...] (for 3d).
|
||||
* \param b_g_k_c_xs_lengths Weight lengths [G, K, C, Spatial...] (for 3d).
|
||||
* \param b_g_k_c_xs_strides Weight strides [G, K, C, Spatial...] (for 3d).
|
||||
* \param ds_g_n_k_wos_lengths Ds lengths [G, N, K, Spatial...] (for 3d).
|
||||
* \param ds_g_n_k_wos_strides Ds strides [G, N, K, Spatial...] (for 3d).
|
||||
* \param e_g_n_k_wos_lengths Output lengths [G, N, K, Spatial...] (for 3d).
|
||||
* \param e_g_n_k_wos_strides Output strides [G, N, K, Spatial...] (for 3d).
|
||||
* \param conv_filter_strides Convolution filter strides.
|
||||
* \param conv_filter_dilations Convolution filter dilations.
|
||||
* \param input_left_pads Input left paddings.
|
||||
* \param input_right_pads Input right paddings.
|
||||
* \param a_element_op A elementwise operation object.
|
||||
* \param b_element_op B elementwise operation object.
|
||||
* \param cde_element_op CDE elementwise operation object.
|
||||
* \return Pointer to the argument.
|
||||
*/
|
||||
virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
APointers p_a,
|
||||
BPointers p_b,
|
||||
const std::array<const void*, NumDTensor>& p_ds,
|
||||
void* p_e,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_lengths,
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>& ds_g_n_k_wos_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
using DeviceGroupedConvFwdMultipleD = DeviceGroupedConvFwdMultipleABD<NDimSpatial,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation,
|
||||
ComputeType>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp"
|
||||
@@ -216,18 +216,18 @@ template <index_t NDimSpatial,
|
||||
index_t CThreadTransferSrcDstVectorDim,
|
||||
index_t CThreadTransferDstScalarPerVector>
|
||||
struct DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK
|
||||
: public DeviceGroupedConvFwdMultipleD<NDimSpatial,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
: public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK;
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1090,7 +1090,7 @@ struct DeviceGroupedConvFwdMultipleDMultipleR_Xdl_CShuffle
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGroupedConvFwdMultipleD_Xdl_CShuffle"
|
||||
str << "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
|
||||
@@ -92,18 +92,18 @@ template <index_t NDimSpatial,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
|
||||
struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
: public DeviceGroupedConvFwdMultipleD<NDimSpatial,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
: public DeviceGroupedConvFwdMultipleABD<NDimSpatial,
|
||||
ALayout,
|
||||
BLayout,
|
||||
DsLayout,
|
||||
ELayout,
|
||||
ADataType,
|
||||
BDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CDEElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedConvFwdMultipleD_Wmma_CShuffle;
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user