mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 08:50:17 +00:00
Fusion Conv+Bias+ReLU(+Add) (#62)
* fix relu * clean up * clean up * adding 1x1 conv * adding 1x1 conv * added 1x1 conv * refactor * refactor * refactor * added profiler for conv+bias+relu+add * clean up * adding conv+bias+relu * adding conv+bias+relu * added conv+bias+relu * Update README.md * update cpu verification * adding c shuffle * update static_tensor for dealing with invalid element * adding c shuffle * debugging * fix bug * convert to fp16 before shuffle * shuffle more than one M/NRepeat * clean up * remove coordinate step hack from GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1 * clean up * remove coordinate step hack from all gridwise gemm xdl * clean up coordinate step hack * clean up coordinate step hack * ThreadwiseTensorSliceTransfer_v3r2 support pointwise op on both src and dst * adding output shuffle in conv+bias+relu+add * update * added conv+bias+relu+add with c shuffle * added conv+bias+relu+add with c shuffle * fix forward_sweep bugs in threadwise copy * clean up * refactor * clean up * clean up * added conv_c_shuffle+bias_relu * clean up * added conv+bias+relu+atomic_add * clean up * clean up * clean up * clean up * clean up * clean up * misc fixes; add 1x1 specialization * clean up * delete unused device op * clean up * add support for odd C value
This commit is contained in:
@@ -0,0 +1,19 @@
|
||||
#ifndef CONVOLUTION_FORWARD_SPECIALIZATION
|
||||
#define CONVOLUTION_FORWARD_SPECIALIZATION
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
enum ConvolutionForwardSpecialization_t
|
||||
{
|
||||
Default,
|
||||
Filter1x1Pad0,
|
||||
Filter1x1Stride1Pad0,
|
||||
OddC,
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,6 +1,8 @@
|
||||
#ifndef DEVICE_BASE_HPP
|
||||
#define DEVICE_BASE_HPP
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
@@ -32,6 +34,7 @@ struct BaseOperator
|
||||
BaseOperator& operator=(const BaseOperator&) = default;
|
||||
|
||||
virtual bool IsSupportedArgument(const BaseArgument*) = 0;
|
||||
virtual std::string GetTypeString() const = 0;
|
||||
|
||||
virtual ~BaseOperator() {}
|
||||
};
|
||||
|
||||
@@ -1,110 +0,0 @@
|
||||
#ifndef DEVICE_CONV_HPP
|
||||
#define DEVICE_CONV_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include "device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
struct DeviceConvFwd : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in,
|
||||
const void* p_wei,
|
||||
void* p_out,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
struct DeviceConvBwd : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(void* p_in,
|
||||
const void* p_wei,
|
||||
const void* p_out,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
struct DeviceConvWrw : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in,
|
||||
void* p_wei,
|
||||
const void* p_out,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
using DeviceConvFwdPtr = std::unique_ptr<
|
||||
DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;
|
||||
|
||||
template <typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
using DeviceConvBwdPtr = std::unique_ptr<
|
||||
DeviceConvBwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;
|
||||
|
||||
template <typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
using DeviceConvWrwPtr = std::unique_ptr<
|
||||
DeviceConvWrw<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,944 @@
|
||||
#ifndef DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_BIAS_ACTIVATION_ADD_NHWC_KYXC_NHWK_HPP
|
||||
#define DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_BIAS_ACTIVATION_ADD_NHWC_KYXC_NHWK_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device.hpp"
|
||||
#include "device_base.hpp"
|
||||
#include "device_conv_fwd_bias_activation_add.hpp"
|
||||
#include "convolution_forward_specialization.hpp"
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_xdlops_v3r3.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// out[N, Ho, Wo, K] =
|
||||
// activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K]) + residual[N, Ho, Wo, K]
|
||||
template <
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerXDL,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool ABlockLdsAddExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BBlockLdsAddExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
|
||||
index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
|
||||
struct
|
||||
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
: public DeviceConvFwdBiasActivationAdd<InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
{
|
||||
using DeviceOp =
|
||||
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;
|
||||
|
||||
using ADataType = InDataType;
|
||||
using BDataType = WeiDataType;
|
||||
using CDataType = OutDataType;
|
||||
|
||||
// TODO make A/B datatype different
|
||||
using ABDataType = InDataType;
|
||||
|
||||
// TODO make it support any # of spatial dimensions
|
||||
static constexpr index_t NDimSpatial = 2;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
static constexpr auto GemmK1Number = K1Number;
|
||||
|
||||
static auto
|
||||
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
const index_t Hi = input_spatial_lengths[0];
|
||||
const index_t Wi = input_spatial_lengths[1];
|
||||
|
||||
const index_t Ho = output_spatial_lengths[0];
|
||||
const index_t Wo = output_spatial_lengths[1];
|
||||
|
||||
const index_t Y = filter_spatial_lengths[0];
|
||||
const index_t X = filter_spatial_lengths[1];
|
||||
|
||||
const index_t ConvStrideH = conv_filter_strides[0];
|
||||
const index_t ConvStrideW = conv_filter_strides[1];
|
||||
|
||||
const index_t ConvDilationH = conv_filter_dilations[0];
|
||||
const index_t ConvDilationW = conv_filter_dilations[1];
|
||||
|
||||
const index_t InLeftPadH = input_left_pads[0];
|
||||
const index_t InLeftPadW = input_left_pads[1];
|
||||
|
||||
const index_t InRightPadH = input_right_pads[0];
|
||||
const index_t InRightPadW = input_right_pads[1];
|
||||
|
||||
const index_t GemmMRaw = N * Ho * Wo;
|
||||
const index_t GemmN = K;
|
||||
|
||||
const auto GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock);
|
||||
const auto GemmMPad = GemmM - GemmMRaw;
|
||||
|
||||
if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
|
||||
{ // 1x1, stride=1, pad=0
|
||||
const index_t GemmK = Y * X * C;
|
||||
assert(GemmK % GemmK1Number == 0);
|
||||
|
||||
const index_t GemmK0 = GemmK / GemmK1Number;
|
||||
|
||||
// A: input tensor
|
||||
const auto in_gemmmraw_gemmk_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmmraw_gemmk_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_gemmn_gemmk_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C));
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmn_gemmk_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: output tensor
|
||||
const auto out_gemmmraw_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// C0: bias tensor: assume a contiguous vector
|
||||
const auto bias_grid_desc_gemmm_gemmn =
|
||||
make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1));
|
||||
|
||||
// C1: residual tensor: assume same layout as output tensor
|
||||
const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc;
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
bias_grid_desc_gemmm_gemmn,
|
||||
resi_grid_desc_gemmm_gemmn);
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
|
||||
{ // 1x1, pad=0
|
||||
const index_t GemmK = Y * X * C;
|
||||
assert(GemmK % GemmK1Number == 0);
|
||||
|
||||
const index_t GemmK0 = GemmK / GemmK1Number;
|
||||
|
||||
// A: input tensor
|
||||
const auto in_n_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
|
||||
const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
|
||||
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_n_ho_wo_c_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk0_gemmmraw_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmK0),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_gemmn_gemmk_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C));
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmn_gemmk_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: output tensor
|
||||
const auto out_gemmmraw_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// C0: bias tensor: assume a contiguous vector
|
||||
const auto bias_grid_desc_gemmm_gemmn =
|
||||
make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1));
|
||||
|
||||
// C1: residual tensor: assume same layout as output tensor
|
||||
const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc;
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
bias_grid_desc_gemmm_gemmn,
|
||||
resi_grid_desc_gemmm_gemmn);
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::OddC)
|
||||
{ // C = odd value
|
||||
const index_t GemmKRaw = Y * X * C;
|
||||
const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number);
|
||||
const index_t GemmKPad = GemmK - GemmKRaw;
|
||||
const index_t GemmK0 = GemmK / GemmK1Number;
|
||||
|
||||
// A: input tensor
|
||||
const auto in_n_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmkraw_gemmmraw_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkraw_gemmmraw_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKRaw, GemmKPad),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_k_yxc_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
|
||||
|
||||
const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
wei_k_yxc_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K),
|
||||
make_right_pad_transform(GemmKRaw, GemmKPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmk_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: output tensor
|
||||
const auto out_nhowo_k_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmmraw_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_nhowo_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// C0: bias tensor: assume a contiguous vector
|
||||
const auto bias_grid_desc_gemmm_gemmn =
|
||||
make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1));
|
||||
|
||||
// C1: residual tensor: assume same layout as output tensor
|
||||
const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc;
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
bias_grid_desc_gemmm_gemmn,
|
||||
resi_grid_desc_gemmm_gemmn);
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t GemmK = Y * X * C;
|
||||
assert(GemmK % GemmK1Number == 0);
|
||||
|
||||
const index_t GemmK0 = GemmK / GemmK1Number;
|
||||
|
||||
// A: input tensor
|
||||
const auto in_n_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmk_gemmmraw_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk_gemmmraw_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmMRaw)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk0_gemmmraw_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmK0),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_k_yxc_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
|
||||
|
||||
const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
wei_k_yxc_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmk_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: output tensor
|
||||
const auto out_nhowo_k_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmmraw_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_nhowo_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// C0: bias tensor: assume a contiguous vector
|
||||
const auto bias_grid_desc_gemmm_gemmn =
|
||||
make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1));
|
||||
|
||||
// C1: residual tensor: assume same layout as output tensor
|
||||
const auto resi_grid_desc_gemmm_gemmn = out_gemmm_gemmn_grid_desc;
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
bias_grid_desc_gemmm_gemmn,
|
||||
resi_grid_desc_gemmm_gemmn);
|
||||
}
|
||||
}
|
||||
|
||||
using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}));
|
||||
|
||||
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
|
||||
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
|
||||
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
|
||||
using C0GridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I3])>;
|
||||
using C1GridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I4])>;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r3<
|
||||
BlockSize,
|
||||
ABDataType, // TODO: distinguish A/B datatype
|
||||
AccDataType,
|
||||
CDataType,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
CGridDesc_M_N,
|
||||
C0GridDesc_M_N,
|
||||
C1GridDesc_M_N,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder,
|
||||
Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder,
|
||||
2, // ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder,
|
||||
Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder,
|
||||
2, // BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsAddExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXdl>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const InDataType* p_in_grid,
|
||||
const WeiDataType* p_wei_grid,
|
||||
OutDataType* p_out_grid,
|
||||
const OutDataType* p_bias_grid,
|
||||
const OutDataType* p_resi_grid,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
ck::index_t M01,
|
||||
ck::index_t N01,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op)
|
||||
: p_a_grid_{p_in_grid},
|
||||
p_b_grid_{p_wei_grid},
|
||||
p_c_grid_{p_out_grid},
|
||||
p_c0_grid_{p_bias_grid},
|
||||
p_c1_grid_{p_resi_grid},
|
||||
a_grid_desc_k0_m_k1_{},
|
||||
b_grid_desc_k0_n_k1_{},
|
||||
c_grid_desc_m_n_{},
|
||||
c0_grid_desc_m_n_{},
|
||||
c1_grid_desc_m_n_{},
|
||||
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
|
||||
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
|
||||
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
|
||||
block_2_ctile_map_{},
|
||||
M01_{M01},
|
||||
N01_{N01},
|
||||
in_element_op_{in_element_op},
|
||||
wei_element_op_{wei_element_op},
|
||||
out_element_op_{out_element_op},
|
||||
Conv_N_{N},
|
||||
Conv_K_{K},
|
||||
Conv_C_{C},
|
||||
filter_spatial_lengths_{filter_spatial_lengths},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
const auto descs =
|
||||
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
|
||||
a_grid_desc_k0_m_k1_ = descs[I0];
|
||||
b_grid_desc_k0_n_k1_ = descs[I1];
|
||||
c_grid_desc_m_n_ = descs[I2];
|
||||
c0_grid_desc_m_n_ = descs[I3];
|
||||
c1_grid_desc_m_n_ = descs[I4];
|
||||
|
||||
if(GridwiseGemm::CheckValidity(
|
||||
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
|
||||
{
|
||||
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
|
||||
GridwiseGemm::
|
||||
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
|
||||
c_grid_desc_m_n_);
|
||||
|
||||
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
|
||||
GridwiseGemm::
|
||||
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
|
||||
c0_grid_desc_m_n_);
|
||||
|
||||
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
|
||||
GridwiseGemm::
|
||||
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
|
||||
c1_grid_desc_m_n_);
|
||||
|
||||
block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
|
||||
}
|
||||
}
|
||||
|
||||
// private:
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
const CDataType* p_c0_grid_;
|
||||
const CDataType* p_c1_grid_;
|
||||
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
|
||||
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
C0GridDesc_M_N c0_grid_desc_m_n_;
|
||||
C1GridDesc_M_N c1_grid_desc_m_n_;
|
||||
typename GridwiseGemm::
|
||||
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
|
||||
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
|
||||
typename GridwiseGemm::
|
||||
C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
|
||||
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
|
||||
typename GridwiseGemm::
|
||||
C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
|
||||
c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
|
||||
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_;
|
||||
index_t M01_;
|
||||
index_t N01_;
|
||||
InElementwiseOperation in_element_op_;
|
||||
WeiElementwiseOperation wei_element_op_;
|
||||
OutElementwiseOperation out_element_op_;
|
||||
// for checking IsSupportedArgument()
|
||||
index_t Conv_N_;
|
||||
index_t Conv_K_;
|
||||
index_t Conv_C_;
|
||||
std::vector<index_t> filter_spatial_lengths_;
|
||||
std::vector<index_t> conv_filter_strides_;
|
||||
std::vector<index_t> input_left_pads_;
|
||||
std::vector<index_t> input_right_pads_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
{
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
|
||||
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
|
||||
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0)
|
||||
<< ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.c1_grid_desc_m_n_{ " << arg.c1_grid_desc_m_n_.GetLength(I0)
|
||||
<< ", " << arg.c1_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.M01_,
|
||||
arg.N01_))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r3 has invalid setting");
|
||||
}
|
||||
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
|
||||
const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v3r3<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::
|
||||
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::
|
||||
C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::
|
||||
C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_c0_grid_,
|
||||
arg.p_c1_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
|
||||
arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
|
||||
arg.c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
|
||||
arg.in_element_op_,
|
||||
arg.wei_element_op_,
|
||||
arg.out_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v3r3<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::
|
||||
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::
|
||||
C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::
|
||||
C1GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_c0_grid_,
|
||||
arg.p_c1_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
|
||||
arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
|
||||
arg.c1_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
|
||||
arg.in_element_op_,
|
||||
arg.wei_element_op_,
|
||||
arg.out_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// check if it's 1x1, stride=1 conv
|
||||
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
|
||||
arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
|
||||
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
|
||||
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
|
||||
{
|
||||
// check if it's 1x1 conv
|
||||
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
|
||||
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
|
||||
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// vector load A/B matrix from global memory
|
||||
if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 &&
|
||||
arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 &&
|
||||
arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// vector store C matrix into global memory
|
||||
if(!(arg.Conv_K_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Gridwise GEMM size
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.M01_,
|
||||
arg.N01_);
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const InDataType* p_in_grid,
|
||||
const WeiDataType* p_wei_grid,
|
||||
OutDataType* p_out_grid,
|
||||
const OutDataType* p_bias_grid,
|
||||
const OutDataType* p_resi_grid,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op)
|
||||
{
|
||||
return Argument{p_in_grid,
|
||||
p_wei_grid,
|
||||
p_out_grid,
|
||||
p_bias_grid,
|
||||
p_resi_grid,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
1,
|
||||
1,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in_grid,
|
||||
const void* p_wei_grid,
|
||||
void* p_out_grid,
|
||||
const void* p_bias_grid,
|
||||
const void* p_resi_grid,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
|
||||
static_cast<const WeiDataType*>(p_wei_grid),
|
||||
static_cast<OutDataType*>(p_out_grid),
|
||||
static_cast<const OutDataType*>(p_bias_grid),
|
||||
static_cast<const OutDataType*>(p_resi_grid),
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
1,
|
||||
1,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Add_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< K0PerBlock
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,892 @@
|
||||
#ifndef DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_BIAS_ACTIVATION_NHWC_KYXC_NHWK_HPP
|
||||
#define DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_BIAS_ACTIVATION_NHWC_KYXC_NHWK_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device.hpp"
|
||||
#include "device_base.hpp"
|
||||
#include "device_conv_fwd_bias_activation.hpp"
|
||||
#include "convolution_forward_specialization.hpp"
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_xdlops_v3r2.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// out[N, Ho, Wo, K] =
|
||||
// activate(in[N, Hi, Wi, C] * wei[K, Y, X, C] + bias[K])
|
||||
template <
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
InMemoryDataOperationEnum_t OutGlobalMemoryDataOperation,
|
||||
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerXDL,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool ABlockLdsAddExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BBlockLdsAddExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
|
||||
index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
|
||||
struct DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
: public DeviceConvFwdBiasActivation<InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
{
|
||||
using DeviceOp =
|
||||
DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;
|
||||
|
||||
using ADataType = InDataType;
|
||||
using BDataType = WeiDataType;
|
||||
using CDataType = OutDataType;
|
||||
|
||||
// TODO make A/B datatype different
|
||||
using ABDataType = InDataType;
|
||||
|
||||
// TODO make it support any # of spatial dimensions
|
||||
static constexpr index_t NDimSpatial = 2;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
static constexpr auto GemmK1Number = K1Number;
|
||||
|
||||
static auto
|
||||
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
const index_t Hi = input_spatial_lengths[0];
|
||||
const index_t Wi = input_spatial_lengths[1];
|
||||
|
||||
const index_t Ho = output_spatial_lengths[0];
|
||||
const index_t Wo = output_spatial_lengths[1];
|
||||
|
||||
const index_t Y = filter_spatial_lengths[0];
|
||||
const index_t X = filter_spatial_lengths[1];
|
||||
|
||||
const index_t ConvStrideH = conv_filter_strides[0];
|
||||
const index_t ConvStrideW = conv_filter_strides[1];
|
||||
|
||||
const index_t ConvDilationH = conv_filter_dilations[0];
|
||||
const index_t ConvDilationW = conv_filter_dilations[1];
|
||||
|
||||
const index_t InLeftPadH = input_left_pads[0];
|
||||
const index_t InLeftPadW = input_left_pads[1];
|
||||
|
||||
const index_t InRightPadH = input_right_pads[0];
|
||||
const index_t InRightPadW = input_right_pads[1];
|
||||
|
||||
const index_t GemmMRaw = N * Ho * Wo;
|
||||
const index_t GemmN = K;
|
||||
|
||||
const auto GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock);
|
||||
const auto GemmMPad = GemmM - GemmMRaw;
|
||||
|
||||
if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
|
||||
{ // 1x1, stride=1, pad=0
|
||||
const index_t GemmK = Y * X * C;
|
||||
assert(GemmK % GemmK1Number == 0);
|
||||
|
||||
const index_t GemmK0 = GemmK / GemmK1Number;
|
||||
|
||||
// A: input tensor
|
||||
const auto in_gemmmraw_gemmk_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmmraw_gemmk_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_gemmn_gemmk_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C));
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmn_gemmk_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: output tensor
|
||||
const auto out_gemmmraw_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// C0: bias tensor: assume a contiguous vector
|
||||
const auto bias_grid_desc_gemmm_gemmn =
|
||||
make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1));
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
bias_grid_desc_gemmm_gemmn);
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
|
||||
{ // 1x1, pad=0
|
||||
const index_t GemmK = Y * X * C;
|
||||
assert(GemmK % GemmK1Number == 0);
|
||||
|
||||
const index_t GemmK0 = GemmK / GemmK1Number;
|
||||
|
||||
// A: input tensor
|
||||
const auto in_n_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
|
||||
const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
|
||||
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_n_ho_wo_c_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk0_gemmmraw_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmK0),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_gemmn_gemmk_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C));
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmn_gemmk_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: output tensor
|
||||
const auto out_gemmmraw_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// C0: bias tensor: assume a contiguous vector
|
||||
const auto bias_grid_desc_gemmm_gemmn =
|
||||
make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1));
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
bias_grid_desc_gemmm_gemmn);
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::OddC)
|
||||
{ // C = odd value
|
||||
const index_t GemmKRaw = Y * X * C;
|
||||
const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number);
|
||||
const index_t GemmKPad = GemmK - GemmKRaw;
|
||||
const index_t GemmK0 = GemmK / GemmK1Number;
|
||||
|
||||
// A: input tensor
|
||||
const auto in_n_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmkraw_gemmmraw_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkraw_gemmmraw_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKRaw, GemmKPad),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_k_yxc_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
|
||||
|
||||
const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
wei_k_yxc_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K),
|
||||
make_right_pad_transform(GemmKRaw, GemmKPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmk_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: output tensor
|
||||
const auto out_nhowo_k_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmmraw_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_nhowo_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// C0: bias tensor: assume a contiguous vector
|
||||
const auto bias_grid_desc_gemmm_gemmn =
|
||||
make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1));
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
bias_grid_desc_gemmm_gemmn);
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t GemmK = Y * X * C;
|
||||
assert(GemmK % GemmK1Number == 0);
|
||||
|
||||
const index_t GemmK0 = GemmK / GemmK1Number;
|
||||
|
||||
// A: input tensor
|
||||
const auto in_n_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmk_gemmmraw_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk_gemmmraw_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmMRaw)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk0_gemmmraw_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmK0),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_k_yxc_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
|
||||
|
||||
const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
wei_k_yxc_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmk_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: output tensor
|
||||
const auto out_nhowo_k_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmmraw_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_nhowo_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// C0: bias tensor: assume a contiguous vector
|
||||
const auto bias_grid_desc_gemmm_gemmn =
|
||||
make_naive_tensor_descriptor(make_tuple(GemmM, GemmN), make_tuple(I0, I1));
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc,
|
||||
bias_grid_desc_gemmm_gemmn);
|
||||
}
|
||||
}
|
||||
|
||||
using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}));
|
||||
|
||||
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
|
||||
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
|
||||
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
|
||||
using C0GridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I3])>;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r2<
|
||||
BlockSize,
|
||||
ABDataType, // TODO: distinguish A/B datatype
|
||||
AccDataType,
|
||||
CDataType,
|
||||
OutGlobalMemoryDataOperation,
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
CGridDesc_M_N,
|
||||
C0GridDesc_M_N,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
MPerXDL,
|
||||
NPerXDL,
|
||||
K1,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder,
|
||||
Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder,
|
||||
2, // ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder,
|
||||
Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder,
|
||||
2, // BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsAddExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXdl>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const InDataType* p_in_grid,
|
||||
const WeiDataType* p_wei_grid,
|
||||
OutDataType* p_out_grid,
|
||||
const OutDataType* p_bias_grid,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
ck::index_t M01,
|
||||
ck::index_t N01,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op)
|
||||
: p_a_grid_{p_in_grid},
|
||||
p_b_grid_{p_wei_grid},
|
||||
p_c_grid_{p_out_grid},
|
||||
p_c0_grid_{p_bias_grid},
|
||||
a_grid_desc_k0_m_k1_{},
|
||||
b_grid_desc_k0_n_k1_{},
|
||||
c_grid_desc_m_n_{},
|
||||
c0_grid_desc_m_n_{},
|
||||
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
|
||||
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
|
||||
block_2_ctile_map_{},
|
||||
M01_{M01},
|
||||
N01_{N01},
|
||||
in_element_op_{in_element_op},
|
||||
wei_element_op_{wei_element_op},
|
||||
out_element_op_{out_element_op},
|
||||
Conv_N_{N},
|
||||
Conv_K_{K},
|
||||
Conv_C_{C},
|
||||
filter_spatial_lengths_{filter_spatial_lengths},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
const auto descs =
|
||||
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
|
||||
a_grid_desc_k0_m_k1_ = descs[I0];
|
||||
b_grid_desc_k0_n_k1_ = descs[I1];
|
||||
c_grid_desc_m_n_ = descs[I2];
|
||||
c0_grid_desc_m_n_ = descs[I3];
|
||||
|
||||
if(GridwiseGemm::CheckValidity(
|
||||
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
|
||||
{
|
||||
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
|
||||
GridwiseGemm::
|
||||
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
|
||||
c_grid_desc_m_n_);
|
||||
|
||||
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
|
||||
GridwiseGemm::
|
||||
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
|
||||
c0_grid_desc_m_n_);
|
||||
|
||||
block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
|
||||
}
|
||||
}
|
||||
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
const CDataType* p_c0_grid_;
|
||||
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
|
||||
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
C0GridDesc_M_N c0_grid_desc_m_n_;
|
||||
typename GridwiseGemm::
|
||||
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
|
||||
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
|
||||
typename GridwiseGemm::
|
||||
C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
|
||||
c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
|
||||
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_;
|
||||
index_t M01_;
|
||||
index_t N01_;
|
||||
InElementwiseOperation in_element_op_;
|
||||
WeiElementwiseOperation wei_element_op_;
|
||||
OutElementwiseOperation out_element_op_;
|
||||
// for checking IsSupportedArgument()
|
||||
index_t Conv_N_;
|
||||
index_t Conv_K_;
|
||||
index_t Conv_C_;
|
||||
std::vector<index_t> filter_spatial_lengths_;
|
||||
std::vector<index_t> conv_filter_strides_;
|
||||
std::vector<index_t> input_left_pads_;
|
||||
std::vector<index_t> input_right_pads_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
{
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
|
||||
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
|
||||
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.c0_grid_desc_m_n_{ " << arg.c0_grid_desc_m_n_.GetLength(I0)
|
||||
<< ", " << arg.c0_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.M01_,
|
||||
arg.N01_))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r2 has invalid setting");
|
||||
}
|
||||
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
|
||||
const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v3r2<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::
|
||||
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::
|
||||
C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_c0_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
|
||||
arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
|
||||
arg.in_element_op_,
|
||||
arg.wei_element_op_,
|
||||
arg.out_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v3r2<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::
|
||||
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::
|
||||
C0GridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.p_c0_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
|
||||
arg.c0_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
|
||||
arg.in_element_op_,
|
||||
arg.wei_element_op_,
|
||||
arg.out_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// check if it's 1x1, stride=1 conv
|
||||
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
|
||||
arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
|
||||
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
|
||||
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
|
||||
{
|
||||
// check if it's 1x1 conv
|
||||
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
|
||||
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
|
||||
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// vector load A/B matrix from global memory
|
||||
if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 &&
|
||||
arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 &&
|
||||
arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// vector store C matrix into global memory
|
||||
if(!(arg.Conv_K_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Gridwise GEMM size
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.M01_,
|
||||
arg.N01_);
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const InDataType* p_in_grid,
|
||||
const WeiDataType* p_wei_grid,
|
||||
OutDataType* p_out_grid,
|
||||
const OutDataType* p_bias_grid,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op)
|
||||
{
|
||||
return Argument{p_in_grid,
|
||||
p_wei_grid,
|
||||
p_out_grid,
|
||||
p_bias_grid,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
1,
|
||||
1,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in_grid,
|
||||
const void* p_wei_grid,
|
||||
void* p_out_grid,
|
||||
const void* p_bias_grid,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
|
||||
static_cast<const WeiDataType*>(p_wei_grid),
|
||||
static_cast<OutDataType*>(p_out_grid),
|
||||
static_cast<const OutDataType*>(p_bias_grid),
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
1,
|
||||
1,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceConv2dFwdXdl_C_Shuffle_Bias_Activation_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< K0PerBlock
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,857 @@
|
||||
#ifndef DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_NHWC_KYXC_NHWK_HPP
|
||||
#define DEVICE_CONV2D_FWD_XDL_C_SHUFFLE_NHWC_KYXC_NHWK_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device.hpp"
|
||||
#include "device_base.hpp"
|
||||
#include "device_conv_fwd.hpp"
|
||||
#include "convolution_forward_specialization.hpp"
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_xdlops_v3r1.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
|
||||
template <
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerXdl,
|
||||
ck::index_t NPerXdl,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool ABlockLdsAddExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BBlockLdsAddExtraN,
|
||||
index_t CShuffleMXdlPerWavePerShuffle,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
|
||||
index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
|
||||
struct DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
: public DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;
|
||||
|
||||
using ADataType = InDataType;
|
||||
using BDataType = WeiDataType;
|
||||
using CDataType = OutDataType;
|
||||
|
||||
// TODO make A/B datatype different
|
||||
using ABDataType = InDataType;
|
||||
|
||||
static constexpr index_t NDimSpatial = 2;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
|
||||
static constexpr auto K1Number = Number<K1>{};
|
||||
static constexpr auto GemmK1Number = K1Number;
|
||||
|
||||
static auto
|
||||
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
const index_t Hi = input_spatial_lengths[0];
|
||||
const index_t Wi = input_spatial_lengths[1];
|
||||
|
||||
const index_t Ho = output_spatial_lengths[0];
|
||||
const index_t Wo = output_spatial_lengths[1];
|
||||
|
||||
const index_t Y = filter_spatial_lengths[0];
|
||||
const index_t X = filter_spatial_lengths[1];
|
||||
|
||||
const index_t ConvStrideH = conv_filter_strides[0];
|
||||
const index_t ConvStrideW = conv_filter_strides[1];
|
||||
|
||||
const index_t ConvDilationH = conv_filter_dilations[0];
|
||||
const index_t ConvDilationW = conv_filter_dilations[1];
|
||||
|
||||
const index_t InLeftPadH = input_left_pads[0];
|
||||
const index_t InLeftPadW = input_left_pads[1];
|
||||
|
||||
const index_t InRightPadH = input_right_pads[0];
|
||||
const index_t InRightPadW = input_right_pads[1];
|
||||
|
||||
const index_t GemmMRaw = N * Ho * Wo;
|
||||
const index_t GemmN = K;
|
||||
|
||||
const auto GemmM = math::integer_least_multiple(GemmMRaw, MPerBlock);
|
||||
const auto GemmMPad = GemmM - GemmMRaw;
|
||||
|
||||
if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
|
||||
{ // 1x1, stride=1, pad=0
|
||||
const index_t GemmK = Y * X * C;
|
||||
assert(GemmK % GemmK1Number == 0);
|
||||
|
||||
const index_t GemmK0 = GemmK / GemmK1Number;
|
||||
|
||||
// A: input tensor
|
||||
const auto in_gemmmraw_gemmk_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmmraw_gemmk_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_gemmn_gemmk_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C));
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmn_gemmk_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: output tensor
|
||||
const auto out_gemmmraw_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
|
||||
{ // 1x1, pad=0
|
||||
const index_t GemmK = Y * X * C;
|
||||
assert(GemmK % GemmK1Number == 0);
|
||||
|
||||
const index_t GemmK0 = GemmK / GemmK1Number;
|
||||
|
||||
// A: input tensor
|
||||
const auto in_n_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
|
||||
const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
|
||||
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_n_ho_wo_c_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk0_gemmmraw_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmK0),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_gemmn_gemmk_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C));
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmn_gemmk_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: output tensor
|
||||
const auto out_gemmmraw_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization == ConvolutionForwardSpecialization_t::OddC)
|
||||
{ // C = odd value
|
||||
const index_t GemmKRaw = Y * X * C;
|
||||
const index_t GemmK = math::integer_least_multiple(GemmKRaw, K0PerBlock * GemmK1Number);
|
||||
const index_t GemmKPad = GemmK - GemmKRaw;
|
||||
const index_t GemmK0 = GemmK / GemmK1Number;
|
||||
|
||||
// A: input tensor
|
||||
const auto in_n_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmkraw_gemmmraw_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkraw_gemmmraw_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKRaw, GemmKPad),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_k_yxc_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
|
||||
|
||||
const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
wei_k_yxc_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K),
|
||||
make_right_pad_transform(GemmKRaw, GemmKPad)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmk_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: output tensor
|
||||
const auto out_nhowo_k_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmmraw_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_nhowo_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
const index_t GemmK = Y * X * C;
|
||||
assert(GemmK % GemmK1Number == 0);
|
||||
|
||||
const index_t GemmK0 = GemmK / GemmK1Number;
|
||||
|
||||
// A: input tensor
|
||||
const auto in_n_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmk_gemmmraw_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk_gemmmraw_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmMRaw)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk0_gemmmraw_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmK0),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_k_yxc_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
|
||||
|
||||
const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
wei_k_yxc_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmk_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: output tensor
|
||||
const auto out_nhowo_k_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmmraw_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_nhowo_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
}
|
||||
|
||||
using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
1, 1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}));
|
||||
|
||||
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
|
||||
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
|
||||
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v3r1<
|
||||
BlockSize,
|
||||
ABDataType, // TODO: distinguish A/B datatype
|
||||
AccDataType,
|
||||
CDataType,
|
||||
InMemoryDataOperationEnum_t::Set,
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
CGridDesc_M_N,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
MPerXdl,
|
||||
NPerXdl,
|
||||
K1,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder,
|
||||
Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder,
|
||||
2, // ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder,
|
||||
Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder,
|
||||
2, // BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsAddExtraN,
|
||||
CShuffleMXdlPerWavePerShuffle,
|
||||
CShuffleNXdlPerWavePerShuffle,
|
||||
CBlockTransferClusterLengths_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl,
|
||||
CBlockTransferScalarPerVector_NWaveNPerXdl>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const InDataType* p_in_grid,
|
||||
const WeiDataType* p_wei_grid,
|
||||
OutDataType* p_out_grid,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
ck::index_t M01,
|
||||
ck::index_t N01,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op)
|
||||
: p_a_grid_{p_in_grid},
|
||||
p_b_grid_{p_wei_grid},
|
||||
p_c_grid_{p_out_grid},
|
||||
a_grid_desc_k0_m_k1_{},
|
||||
b_grid_desc_k0_n_k1_{},
|
||||
c_grid_desc_m_n_{},
|
||||
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_{},
|
||||
block_2_ctile_map_{},
|
||||
M01_{M01},
|
||||
N01_{N01},
|
||||
in_element_op_{in_element_op},
|
||||
wei_element_op_{wei_element_op},
|
||||
out_element_op_{out_element_op},
|
||||
Conv_N_{N},
|
||||
Conv_K_{K},
|
||||
Conv_C_{C},
|
||||
filter_spatial_lengths_{filter_spatial_lengths},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
const auto descs =
|
||||
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
|
||||
a_grid_desc_k0_m_k1_ = descs[I0];
|
||||
b_grid_desc_k0_n_k1_ = descs[I1];
|
||||
c_grid_desc_m_n_ = descs[I2];
|
||||
|
||||
if(GridwiseGemm::CheckValidity(
|
||||
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_))
|
||||
{
|
||||
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_ =
|
||||
GridwiseGemm::
|
||||
MakeCGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl(
|
||||
c_grid_desc_m_n_);
|
||||
|
||||
block_2_ctile_map_ = GridwiseGemm::MakeBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
|
||||
}
|
||||
}
|
||||
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
|
||||
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
typename GridwiseGemm::
|
||||
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl
|
||||
c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_;
|
||||
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_;
|
||||
index_t M01_;
|
||||
index_t N01_;
|
||||
InElementwiseOperation in_element_op_;
|
||||
WeiElementwiseOperation wei_element_op_;
|
||||
OutElementwiseOperation out_element_op_;
|
||||
// for checking IsSupportedArgument()
|
||||
index_t Conv_N_;
|
||||
index_t Conv_K_;
|
||||
index_t Conv_C_;
|
||||
std::vector<index_t> filter_spatial_lengths_;
|
||||
std::vector<index_t> conv_filter_strides_;
|
||||
std::vector<index_t> input_left_pads_;
|
||||
std::vector<index_t> input_right_pads_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
{
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0)
|
||||
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
|
||||
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
|
||||
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
|
||||
std::cout
|
||||
<< "arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_"
|
||||
"nwavenperxdl_{ "
|
||||
<< arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
|
||||
.GetLength(I0)
|
||||
<< ", "
|
||||
<< arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
|
||||
.GetLength(I1)
|
||||
<< ", "
|
||||
<< arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
|
||||
.GetLength(I2)
|
||||
<< ", "
|
||||
<< arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
|
||||
.GetLength(I3)
|
||||
<< ", "
|
||||
<< arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
|
||||
.GetLength(I4)
|
||||
<< ", "
|
||||
<< arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_
|
||||
.GetLength(I5)
|
||||
<< "}" << std::endl;
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.M01_,
|
||||
arg.N01_))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
|
||||
}
|
||||
|
||||
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
|
||||
const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0);
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v3r1<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::
|
||||
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
|
||||
arg.in_element_op_,
|
||||
arg.wei_element_op_,
|
||||
arg.out_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
else
|
||||
{
|
||||
const auto kernel = kernel_gemm_xdlops_v3r1<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<
|
||||
typename GridwiseGemm::
|
||||
CGridDescriptor_MBlock_MXdlPerWave_MWaveMPerXdl_NBlock_NXdlPerWave_NWaveNPerXdl>,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(
|
||||
kernel,
|
||||
nrepeat,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mxdlperwave_mwavemperxdl_nblock_nxdlperwave_nwavenperxdl_,
|
||||
arg.in_element_op_,
|
||||
arg.wei_element_op_,
|
||||
arg.out_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// check if it's 1x1, stride=1 conv
|
||||
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
|
||||
arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
|
||||
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
|
||||
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
|
||||
{
|
||||
// check if it's 1x1 conv
|
||||
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
|
||||
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
|
||||
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// vector load A/B matrix from global memory
|
||||
if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 &&
|
||||
arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 &&
|
||||
arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// vector store C matrix into global memory
|
||||
if(!(arg.Conv_K_ % CBlockTransferScalarPerVector_NWaveNPerXdl == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Gridwise GEMM size
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.M01_,
|
||||
arg.N01_);
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto MakeArgument(const InDataType* p_in_grid,
|
||||
const WeiDataType* p_wei_grid,
|
||||
OutDataType* p_out_grid,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op)
|
||||
{
|
||||
return Argument{p_in_grid,
|
||||
p_wei_grid,
|
||||
p_out_grid,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
1,
|
||||
1,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in_grid,
|
||||
const void* p_wei_grid,
|
||||
void* p_out_grid,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
|
||||
static_cast<const WeiDataType*>(p_wei_grid),
|
||||
static_cast<OutDataType*>(p_out_grid),
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
1,
|
||||
1,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceConv2dFwdXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< K0PerBlock
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,23 +1,23 @@
|
||||
#ifndef DEVICE_CONV_FWD_XDL_NHWC_KYXC_NHWK_HPP
|
||||
#define DEVICE_CONV_FWD_XDL_NHWC_KYXC_NHWK_HPP
|
||||
#ifndef DEVICE_CONV2D_FWD_XDL_NHWC_KYXC_NHWK_HPP
|
||||
#define DEVICE_CONV2D_FWD_XDL_NHWC_KYXC_NHWK_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device.hpp"
|
||||
#include "device_base.hpp"
|
||||
#include "device_conv.hpp"
|
||||
#include "device_conv_fwd.hpp"
|
||||
#include "convolution_forward_specialization.hpp"
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_xdlops_v2r3.hpp"
|
||||
#include "device_conv.hpp"
|
||||
#include "device_conv_fwd_xdl.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// specialization for 2D conv: in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
|
||||
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
|
||||
template <typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
@@ -25,6 +25,7 @@ template <typename InDataType,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
ConvolutionForwardSpecialization_t ConvForwardSpecialization,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
@@ -34,68 +35,27 @@ template <typename InDataType,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
typename BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
bool ABlockLdsAddExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BBlockLdsAddExtraN,
|
||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||
ck::index_t CThreadTransferDstScalarPerVector,
|
||||
bool ABlockLdsAddExtraM,
|
||||
bool BBlockLdsAddExtraN>
|
||||
struct DeviceConvFwdXdl<
|
||||
2, // ck::index_t NDimSpatial,
|
||||
InDataType, // typename InDataType,
|
||||
WeiDataType, // typename WeiDataType,
|
||||
OutDataType, // typename OutDataType,
|
||||
AccDataType, // typename AccDataType,
|
||||
ck::tensor_layout::convolution::NHWC, // typename InLayout,
|
||||
ck::tensor_layout::convolution::KYXC, // typename WeiLayout,
|
||||
ck::tensor_layout::convolution::NHWK, // typename OutLayout,
|
||||
InElementwiseOperation, // typename InElementwiseOperation,
|
||||
WeiElementwiseOperation, // typename WeiElementwiseOperation,
|
||||
OutElementwiseOperation, // typename OutElementwiseOperation,
|
||||
BlockSize, // ck::index_t BlockSize,
|
||||
MPerBlock, // ck::index_t MPerBlock,
|
||||
NPerBlock, // ck::index_t NPerBlock,
|
||||
K0PerBlock, // ck::index_t K0PerBlock,
|
||||
K1, // ck::index_t K1,
|
||||
MPerXDL, // ck::index_t MPerXDL,
|
||||
NPerXDL, // ck::index_t NPerXDL,
|
||||
MXdlPerWave, // ck::index_t MXdlPerWave,
|
||||
NXdlPerWave, // ck::index_t NXdlPerWave,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1, // typename ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1, // typename
|
||||
// ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder, // typename ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder, // typename ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim, // ck::index_t ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector, // ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1, // ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1, // typename BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1, // typename
|
||||
// BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder, // typename BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder, // typename BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim, // ck::index_t BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector, // ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1, // ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
CThreadTransferSrcDstVectorDim, // ck::index_t CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector, // ck::index_t CThreadTransferDstScalarPerVector,
|
||||
ABlockLdsAddExtraM, // bool ABlockLdsAddExtraM,
|
||||
BBlockLdsAddExtraN // bool BBlockLdsAddExtraN>
|
||||
>
|
||||
ck::index_t CThreadTransferDstScalarPerVector>
|
||||
struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
: public DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;
|
||||
|
||||
using ADataType = InDataType;
|
||||
using BDataType = WeiDataType;
|
||||
using CDataType = OutDataType;
|
||||
@@ -103,7 +63,6 @@ struct DeviceConvFwdXdl<
|
||||
// TODO make A/B datatype different
|
||||
using ABDataType = InDataType;
|
||||
|
||||
// TODO make it support any # of spatial dimensions
|
||||
static constexpr index_t NDimSpatial = 2;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
@@ -159,88 +118,189 @@ struct DeviceConvFwdXdl<
|
||||
|
||||
const index_t GemmK0 = GemmK / GemmK1Number;
|
||||
|
||||
// A: input tensor
|
||||
const auto in_n_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// A: input tensor
|
||||
const auto in_gemmmraw_gemmk_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, C));
|
||||
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmmraw_gemmk_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
// B: weight tensor
|
||||
const auto wei_gemmn_gemmk_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C));
|
||||
|
||||
const auto in_gemmk_gemmmraw_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmn_gemmk_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk_gemmmraw_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmMRaw)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
// C: output tensor
|
||||
const auto out_gemmmraw_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc =
|
||||
transform_tensor_descriptor(in_gemmk0_gemmmraw_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmK0),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_k_yxc_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
|
||||
{
|
||||
// A: input tensor
|
||||
const auto in_n_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
|
||||
const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
wei_k_yxc_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
const auto in_n_ho_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Ho), make_tuple(ConvStrideH)),
|
||||
make_embed_transform(make_tuple(Wo), make_tuple(ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmk_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_n_ho_wo_c_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: output tensor
|
||||
const auto out_nhowo_k_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk0_gemmmraw_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmK0),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto out_gemmmraw_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
out_nhowo_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo), make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
// B: weight tensor
|
||||
const auto wei_gemmn_gemmk_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, C));
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmn_gemmk_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc);
|
||||
// C: output tensor
|
||||
const auto out_gemmmraw_gemmn_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
// A: input tensor
|
||||
const auto in_n_hi_wi_c_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N, Hi, Wi, C));
|
||||
|
||||
const auto in_n_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hi_wi_c_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}));
|
||||
|
||||
const auto in_n_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_hip_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1, 2>{}, Sequence<3, 4>{}, Sequence<5>{}));
|
||||
|
||||
const auto in_gemmk_gemmmraw_grid_desc =
|
||||
transform_tensor_descriptor(in_n_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5>{}, Sequence<0, 2, 4>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmmraw_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk_gemmmraw_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmMRaw)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmk0_gemmmraw_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmK0),
|
||||
make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_k_yxc_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K, Y * X * C));
|
||||
|
||||
const auto wei_gemmk_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
wei_k_yxc_grid_desc,
|
||||
make_tuple(make_pass_through_transform(K), make_pass_through_transform(Y * X * C)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto wei_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmk_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// C: output tensor
|
||||
const auto out_nhowo_k_grid_desc =
|
||||
make_naive_tensor_descriptor_packed(make_tuple(N * Ho * Wo, K));
|
||||
|
||||
const auto out_gemmmraw_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_nhowo_k_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N * Ho * Wo),
|
||||
make_pass_through_transform(K)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmm_gemmn_grid_desc =
|
||||
transform_tensor_descriptor(out_gemmmraw_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmMRaw, GemmMPad),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(in_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
wei_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
out_gemmm_gemmn_grid_desc);
|
||||
}
|
||||
}
|
||||
|
||||
using ABCGridDescs = decltype(MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
@@ -250,46 +310,6 @@ struct DeviceConvFwdXdl<
|
||||
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
|
||||
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
|
||||
|
||||
// TODO remove these hacks
|
||||
static constexpr auto a_k0_m_k1_grid_step_hacks = make_tuple(
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}, // 0+: K0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0>{}, // 1+: M
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{}, // 0-: K0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0>{}, // 1-: M
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0>{})); // 2-: K1
|
||||
|
||||
static constexpr auto b_k0_n_k1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0+: K0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1+: N
|
||||
Sequence<0, 0, 0, 0, 0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0>{}, // 0-: K0
|
||||
Sequence<0, 0, 0, 0, 0>{}, // 1-: N
|
||||
Sequence<0, 0, 0, 0, 0>{})); // 2-: K1
|
||||
|
||||
static constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
static constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks =
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, 0, 0, 0>{};
|
||||
|
||||
static constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0, 0, 0>{};
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
|
||||
BlockSize,
|
||||
@@ -311,7 +331,6 @@ struct DeviceConvFwdXdl<
|
||||
K1,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
Sequence<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder,
|
||||
Sequence<1, 0, 2>, // ABlockTransferSrcAccessOrder,
|
||||
@@ -319,30 +338,18 @@ struct DeviceConvFwdXdl<
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
Sequence<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder,
|
||||
Sequence<1, 0, 2>, // BBlockTransferSrcAccessOrder,
|
||||
2, // BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsAddExtraN,
|
||||
Sequence<2, 3, 0, 1, 7, 5, 4, 6>, // CThreadTransferSrcDstAccessOrder,
|
||||
7, // CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
decltype(a_k0_m_k1_grid_step_hacks), // AGridStepHacks,
|
||||
decltype(b_k0_n_k1_grid_step_hacks), // BGridStepHacks,
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), // CGridStepHacks,
|
||||
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), // AGridMoveSliceWindowStepHacks,
|
||||
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), // BGridMoveSliceWindowStepHacks,
|
||||
false, // CAccessOrderMRepeatNRepeat,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockLdsAddExtraN>;
|
||||
|
||||
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
|
||||
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
|
||||
|
||||
using Block2CTileMap = decltype(GridwiseGemm::MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
|
||||
CThreadTransferDstScalarPerVector>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
@@ -377,19 +384,26 @@ struct DeviceConvFwdXdl<
|
||||
N01_{N01},
|
||||
in_element_op_{in_element_op},
|
||||
wei_element_op_{wei_element_op},
|
||||
out_element_op_{out_element_op}
|
||||
out_element_op_{out_element_op},
|
||||
Conv_N_{N},
|
||||
Conv_K_{K},
|
||||
Conv_C_{C},
|
||||
filter_spatial_lengths_{filter_spatial_lengths},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
const auto descs = DeviceConvFwdXdl::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
const auto descs =
|
||||
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
|
||||
a_grid_desc_k0_m_k1_ = descs[I0];
|
||||
b_grid_desc_k0_n_k1_ = descs[I1];
|
||||
@@ -412,19 +426,28 @@ struct DeviceConvFwdXdl<
|
||||
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
|
||||
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
|
||||
Block2CTileMap block_2_ctile_map_;
|
||||
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
|
||||
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
|
||||
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_;
|
||||
index_t M01_;
|
||||
index_t N01_;
|
||||
InElementwiseOperation in_element_op_;
|
||||
WeiElementwiseOperation wei_element_op_;
|
||||
OutElementwiseOperation out_element_op_;
|
||||
// for checking IsSupportedArgument()
|
||||
index_t Conv_N_;
|
||||
index_t Conv_K_;
|
||||
index_t Conv_C_;
|
||||
std::vector<index_t> filter_spatial_lengths_;
|
||||
std::vector<index_t> conv_filter_strides_;
|
||||
std::vector<index_t> input_left_pads_;
|
||||
std::vector<index_t> input_right_pads_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceConvFwdXdl::Argument;
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, int nrepeat = 1)
|
||||
{
|
||||
@@ -465,13 +488,13 @@ struct DeviceConvFwdXdl<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceConvFwdXdl::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceConvFwdXdl::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceConvFwdXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
remove_reference_t<DeviceConvFwdXdl::Block2CTileMap>,
|
||||
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
@@ -496,13 +519,13 @@ struct DeviceConvFwdXdl<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceConvFwdXdl::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceConvFwdXdl::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceConvFwdXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
remove_reference_t<DeviceConvFwdXdl::Block2CTileMap>,
|
||||
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
@@ -525,7 +548,6 @@ struct DeviceConvFwdXdl<
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
float Run(const BaseArgument* p_arg, int nrepeat = 1) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), nrepeat);
|
||||
@@ -540,6 +562,45 @@ struct DeviceConvFwdXdl<
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization_t::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// check if it's 1x1, stride=1 conv
|
||||
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
|
||||
arg.conv_filter_strides_[0] == 1 && arg.conv_filter_strides_[1] == 1 &&
|
||||
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
|
||||
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else if constexpr(ConvForwardSpecialization ==
|
||||
ConvolutionForwardSpecialization_t::Filter1x1Pad0)
|
||||
{
|
||||
// check if it's 1x1 conv
|
||||
if(!(arg.filter_spatial_lengths_[0] == 1 && arg.filter_spatial_lengths_[1] == 1 &&
|
||||
arg.input_left_pads_[0] == 0 && arg.input_left_pads_[1] == 0 &&
|
||||
arg.input_right_pads_[0] == 0 && arg.input_right_pads_[1] == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// vector load A/B matrix from global memory
|
||||
if(!(ABlockTransferSrcVectorDim == 2 && BBlockTransferSrcVectorDim == 2 &&
|
||||
arg.Conv_C_ % ABlockTransferSrcScalarPerVector == 0 &&
|
||||
arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// vector store C matrix into global memory
|
||||
if(!(arg.Conv_K_ % CThreadTransferDstScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Gridwise GEMM size
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_,
|
||||
arg.b_grid_desc_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
@@ -547,7 +608,6 @@ struct DeviceConvFwdXdl<
|
||||
arg.N01_);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
@@ -592,7 +652,6 @@ struct DeviceConvFwdXdl<
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in_grid,
|
||||
const void* p_wei_grid,
|
||||
@@ -631,11 +690,27 @@ struct DeviceConvFwdXdl<
|
||||
out_element_op);
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< K0PerBlock
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
}; // namespace device
|
||||
|
||||
} // namespace device
|
||||
46
device_operation/include/device_conv_fwd.hpp
Normal file
46
device_operation/include/device_conv_fwd.hpp
Normal file
@@ -0,0 +1,46 @@
|
||||
#ifndef DEVICE_CONV_FWD_HPP
|
||||
#define DEVICE_CONV_FWD_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include "device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
struct DeviceConvFwd : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in,
|
||||
const void* p_wei,
|
||||
void* p_out,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
using DeviceConvFwdPtr = std::unique_ptr<
|
||||
DeviceConvFwd<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
49
device_operation/include/device_conv_fwd_bias_activation.hpp
Normal file
49
device_operation/include/device_conv_fwd_bias_activation.hpp
Normal file
@@ -0,0 +1,49 @@
|
||||
#ifndef DEVICE_CONV_FWD_BIAS_ACTIVATION_HPP
|
||||
#define DEVICE_CONV_FWD_BIAS_ACTIVATION_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include "device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
struct DeviceConvFwdBiasActivation : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in,
|
||||
const void* p_wei,
|
||||
void* p_out,
|
||||
const void* p_bias,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
using DeviceConvFwdBiasActivationPtr =
|
||||
std::unique_ptr<DeviceConvFwdBiasActivation<InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -0,0 +1,50 @@
|
||||
#ifndef DEVICE_CONV_FWD_BIAS_ACTIVATION_ADD_HPP
|
||||
#define DEVICE_CONV_FWD_BIAS_ACTIVATION_ADD_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include "device_base.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
struct DeviceConvFwdBiasActivationAdd : public BaseOperator
|
||||
{
|
||||
virtual std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in,
|
||||
const void* p_wei,
|
||||
void* p_out,
|
||||
const void* p_bias,
|
||||
const void* p_resi,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
template <typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation>
|
||||
using DeviceConvFwdBiasActivationAddPtr =
|
||||
std::unique_ptr<DeviceConvFwdBiasActivationAdd<InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>>;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,61 +0,0 @@
|
||||
#ifndef DEVICE_CONV_FWD_XDL_HPP
|
||||
#define DEVICE_CONV_FWD_XDL_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include "device.hpp"
|
||||
#include "device_base.hpp"
|
||||
#include "device_conv.hpp"
|
||||
#include "common_header.hpp"
|
||||
#include "tensor_layout.hpp"
|
||||
#include "tensor_descriptor.hpp"
|
||||
#include "tensor_descriptor_helper.hpp"
|
||||
#include "gridwise_gemm_xdlops_v2r3.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerXDL,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
typename BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||
ck::index_t CThreadTransferDstScalarPerVector,
|
||||
bool ABlockLdsAddExtraM,
|
||||
bool BBlockLdsAddExtraN>
|
||||
struct DeviceConvFwdXdl;
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,52 +0,0 @@
|
||||
#ifndef DEVICE_CONV_INSTANTCE_HPP
|
||||
#define DEVICE_CONV_INSTANTCE_HPP
|
||||
|
||||
#include "device_conv.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_conv_instance {
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
void add_device_conv_fwd_instance(
|
||||
std::vector<DeviceConvFwdPtr<ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>>&);
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
void add_device_conv_bwd_instance(
|
||||
std::vector<DeviceConvBwdPtr<ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>>&);
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
void add_device_conv_wrw_instance(
|
||||
std::vector<DeviceConvWrwPtr<ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough>>&);
|
||||
|
||||
} // namespace device_conv_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -2,6 +2,7 @@
|
||||
#define DEVICE_GEMM_XDL_HPP
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
#include "device.hpp"
|
||||
#include "device_base.hpp"
|
||||
#include "device_gemm.hpp"
|
||||
@@ -34,24 +35,22 @@ template <typename ADataType,
|
||||
ck::index_t NPerXDL,
|
||||
ck::index_t MXdlPerWave,
|
||||
ck::index_t NXdlPerWave,
|
||||
typename ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
ck::index_t ABlockTransferSrcVectorDim,
|
||||
ck::index_t ABlockTransferSrcScalarPerVector,
|
||||
ck::index_t ABlockTransferDstScalarPerVector_K1,
|
||||
typename BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
bool ABlockLdsAddExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
ck::index_t BBlockTransferSrcVectorDim,
|
||||
ck::index_t BBlockTransferSrcScalarPerVector,
|
||||
ck::index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BBlockLdsAddExtraN,
|
||||
ck::index_t CThreadTransferSrcDstVectorDim,
|
||||
ck::index_t CThreadTransferDstScalarPerVector,
|
||||
bool ABlockLdsAddExtraM,
|
||||
bool BBlockLdsAddExtraN>
|
||||
ck::index_t CThreadTransferDstScalarPerVector>
|
||||
struct DeviceGemmXdl
|
||||
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
|
||||
{
|
||||
@@ -131,45 +130,6 @@ struct DeviceGemmXdl
|
||||
using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
|
||||
using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
|
||||
|
||||
// TODO remove these hacks
|
||||
static constexpr auto a_k0_m_k1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0
|
||||
Sequence<0, 0, 0>{}, // 1+: M
|
||||
Sequence<0, 0, 0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0, 0, 0>{}, // 0-: K0
|
||||
Sequence<0, 0, 0>{}, // 1-: M
|
||||
Sequence<0, 0, 0>{})); // 2-: K1
|
||||
|
||||
static constexpr auto b_k0_n_k1_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0>{}, // 0+: K0
|
||||
Sequence<0, 0, 0>{}, // 1+: N
|
||||
Sequence<0, 0, 0>{}), // 2+: K1
|
||||
make_tuple(Sequence<0, 0, 0>{}, // 0-: K0
|
||||
Sequence<0, 0, 0>{}, // 1-: N
|
||||
Sequence<0, 0, 0>{})); // 2-: K1
|
||||
|
||||
static constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
|
||||
make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0+: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1+: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2+: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3+: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4+: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5+: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6+: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}), // 7+: N2
|
||||
make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 0-: M0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 1-: N0
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 2-: M1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 3-: N1
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 4-: M2
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 5-: M3
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}, // 6-: M4
|
||||
Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2
|
||||
|
||||
static constexpr auto a_k0_m_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
|
||||
|
||||
static constexpr auto b_k0_n_k1_grid_move_slice_window_step_hacks = Sequence<0, 0, 0>{};
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3<
|
||||
BlockSize,
|
||||
@@ -191,7 +151,6 @@ struct DeviceGemmXdl
|
||||
K1,
|
||||
MXdlPerWave,
|
||||
NXdlPerWave,
|
||||
ABlockTransferThreadSliceLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
@@ -199,30 +158,18 @@ struct DeviceGemmXdl
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false, // AThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockTransferThreadSliceLengths_K0_N_K1,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
false, // BThreadTransferSrcResetCoordinateAfterRun,
|
||||
BBlockLdsAddExtraN,
|
||||
Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder,
|
||||
CThreadTransferSrcDstVectorDim,
|
||||
CThreadTransferDstScalarPerVector,
|
||||
decltype(a_k0_m_k1_grid_step_hacks), // AGridStepHacks,
|
||||
decltype(b_k0_n_k1_grid_step_hacks), // BGridStepHacks,
|
||||
decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks), // CGridStepHacks,
|
||||
decltype(a_k0_m_k1_grid_move_slice_window_step_hacks), // AGridMoveSliceWindowStepHacks,
|
||||
decltype(b_k0_n_k1_grid_move_slice_window_step_hacks), // BGridMoveSliceWindowStepHacks,
|
||||
false, // CAccessOrderMRepeatNRepeat,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockLdsAddExtraN>;
|
||||
|
||||
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
|
||||
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
|
||||
|
||||
using Block2CTileMap = decltype(GridwiseGemm::MakeBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
|
||||
CThreadTransferDstScalarPerVector>;
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
@@ -276,8 +223,9 @@ struct DeviceGemmXdl
|
||||
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
|
||||
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
|
||||
Block2CTileMap block_2_ctile_map_;
|
||||
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
|
||||
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
|
||||
typename GridwiseGemm::Block2CTileMap block_2_ctile_map_;
|
||||
index_t M01_;
|
||||
index_t N01_;
|
||||
AElementwiseOperation a_element_op_;
|
||||
@@ -331,11 +279,11 @@ struct DeviceGemmXdl
|
||||
CDataType,
|
||||
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceGemmXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
remove_reference_t<DeviceGemmXdl::Block2CTileMap>,
|
||||
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
|
||||
true>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
@@ -362,11 +310,11 @@ struct DeviceGemmXdl
|
||||
CDataType,
|
||||
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceGemmXdl::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
remove_reference_t<DeviceGemmXdl::Block2CTileMap>,
|
||||
remove_reference_t<typename GridwiseGemm::Block2CTileMap>,
|
||||
false>;
|
||||
|
||||
ave_time = launch_and_time_kernel(kernel,
|
||||
@@ -483,6 +431,24 @@ struct DeviceGemmXdl
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
// polymorphic
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGemmXdl"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< K0PerBlock
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
26
device_operation/include/device_operation_instance.hpp
Normal file
26
device_operation/include/device_operation_instance.hpp
Normal file
@@ -0,0 +1,26 @@
|
||||
#ifndef CK_DEVICE_OPERATION_INSTANCE_HPP
|
||||
#define CK_DEVICE_OPERATION_INSTANCE_HPP
|
||||
|
||||
#include <stdlib.h>
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <typename OpInstance, typename NewOpInstances>
|
||||
void add_device_operation_instances(std::vector<std::unique_ptr<OpInstance>>& op_instances,
|
||||
const NewOpInstances& new_op_instances)
|
||||
{
|
||||
ck::static_for<0, std::tuple_size_v<NewOpInstances>, 1>{}([&](auto i) {
|
||||
const auto new_op_instance = std::get<i>(new_op_instances);
|
||||
|
||||
using NewOpInstance = remove_cvref_t<decltype(new_op_instance)>;
|
||||
|
||||
op_instances.push_back(std::make_unique<NewOpInstance>(new_op_instance));
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -1,20 +0,0 @@
|
||||
#ifndef ELEMENT_WISE_OPERATION_HPP
|
||||
#define ELEMENT_WISE_OPERATION_HPP
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace element_wise {
|
||||
|
||||
struct PassThrough
|
||||
{
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr T operator()(T v) const
|
||||
{
|
||||
return v;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace element_wise
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
#endif
|
||||
Reference in New Issue
Block a user