NHWC Conv2d Bwd weight fp16 ckprofiler and test (#166)

* change backward weight name

* start add bwd weight lib and profiler

* change tuning paramter

* change output info

* add bwd weight test

* change test info

* using conv_util

* change wgt to weight

* add }

* add fp32
This commit is contained in:
ltqin
2022-04-05 09:32:00 +08:00
committed by GitHub
parent 82c8b9f8ee
commit 781cacd2e6
19 changed files with 814 additions and 42 deletions

View File

@@ -52,10 +52,13 @@ template <typename InDataType,
index_t CShuffleNXdlPerWavePerShuffle,
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
: public DeviceConvWrw<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>
struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
: public DeviceConvBwdWeight<InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation>
{
using DeviceOp = DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;
using DeviceOp =
DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;
using ADataType = OutDataType;
using BDataType = InDataType;
@@ -68,8 +71,6 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
// TODO make A/B datatype different
using ABDataType = InDataType;
static constexpr index_t NDimSpatial = 2;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
static constexpr auto I2 = Number<2>{};
@@ -691,7 +692,7 @@ struct DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_W
auto str = std::stringstream();
// clang-format off
str << "DeviceConv2dWrWXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
str << "DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K"
<< "<"
<< BlockSize << ", "
<< MPerBlock << ", "

View File

@@ -11,7 +11,7 @@ namespace device {
template <typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
struct DeviceConvWrw : public BaseOperator
struct DeviceConvBwdWeight : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* p_in,
@@ -38,8 +38,8 @@ struct DeviceConvWrw : public BaseOperator
template <typename InElementwiseOperation,
typename WeiElementwiseOperation,
typename OutElementwiseOperation>
using DeviceConvWrwPtr = std::unique_ptr<
DeviceConvWrw<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;
using DeviceConvBwdWeightPtr = std::unique_ptr<
DeviceConvBwdWeight<InElementwiseOperation, WeiElementwiseOperation, OutElementwiseOperation>>;
} // namespace device
} // namespace tensor_operation