mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
Add client example of grouped conv2d backward weight (data type: fp16) (#498)
* Remove redundant CMake setting * Extract common code from files * Rename folder 'convnd' to 'conv' * Use std::array<> to accept compile-time kwnown # of arguments * Fix compilation error of tuning parameter * In example, use same setting as unit-test * Remove no-longer used include directive * Add interface for grouped conv bwd weight * Add group support for conv bwd weight * Add grouped conv bwd weight example * Use group parameter in example * Rename example folder * Remove non-grouped version example source files * Rename device op template * Add group support to convolution backward weight * Remove debug messages * Use smaller group size in example * Use named variable as loop terminate condition * Prettify example output message * Enlarge used grid size * Allow real grid size exceeds expected grid size * Rename interface file * Add client example for grouped conv2d bwd weight * Fix wrong include directive * Rename client example folder
This commit is contained in:
@@ -67,6 +67,8 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
{
|
||||
static constexpr ck::index_t NDimSpatial = 2;
|
||||
|
||||
using DeviceOp =
|
||||
DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K;
|
||||
|
||||
@@ -107,18 +109,18 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
static constexpr auto BBlockLdsN0PerBlock = NPerBlock / BBlockLdsN1PerBlock;
|
||||
static constexpr auto BBlockLdsN1Padding = 4;
|
||||
|
||||
static auto
|
||||
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
ck::index_t batch_k)
|
||||
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
ck::index_t batch_k)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -390,13 +392,13 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
ck::index_t M01,
|
||||
ck::index_t N01,
|
||||
InElementwiseOperation in_element_op,
|
||||
@@ -473,11 +475,11 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
index_t Conv_N_;
|
||||
index_t Conv_K_;
|
||||
index_t Conv_C_;
|
||||
std::vector<index_t> output_spatial_lengths_;
|
||||
std::vector<index_t> filter_spatial_lengths_;
|
||||
std::vector<index_t> conv_filter_strides_;
|
||||
std::vector<index_t> input_left_pads_;
|
||||
std::vector<index_t> input_right_pads_;
|
||||
std::array<index_t, NDimSpatial> output_spatial_lengths_;
|
||||
std::array<index_t, NDimSpatial> filter_spatial_lengths_;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_;
|
||||
index_t k_batch_;
|
||||
};
|
||||
|
||||
@@ -682,13 +684,13 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
@@ -724,13 +726,13 @@ struct DeviceConv2dBwdWeightXdl_C_Shuffle_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
|
||||
@@ -4,13 +4,14 @@
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_conv_bwd_weight.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
@@ -20,6 +21,103 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
namespace {
|
||||
|
||||
struct ComputePtrOffsetOfStridedBatch
|
||||
{
|
||||
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideA_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideC_);
|
||||
}
|
||||
|
||||
index_t BatchStrideA_;
|
||||
index_t BatchStrideB_;
|
||||
index_t BatchStrideC_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
typename AElementwiseOperation,
|
||||
typename BElementwiseOperation,
|
||||
typename CElementwiseOperation,
|
||||
typename AGridDesc_B_K0_M_K1,
|
||||
typename BGridDesc_B_K0_N_K1,
|
||||
typename CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename Block2CTileMap,
|
||||
typename ComputePtrOffsetOfBatch,
|
||||
bool HasMainKBlockLoop>
|
||||
__global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_batched_gemm_xdlops_bwd_weight(
|
||||
const FloatAB* __restrict__ p_a_grid,
|
||||
const FloatAB* __restrict__ p_b_grid,
|
||||
FloatC* __restrict__ p_c_grid,
|
||||
const AElementwiseOperation a_element_op,
|
||||
const BElementwiseOperation b_element_op,
|
||||
const CElementwiseOperation c_element_op,
|
||||
const index_t batch_count,
|
||||
const AGridDesc_B_K0_M_K1 a_b_k0_m_k1_grid_desc,
|
||||
const BGridDesc_B_K0_N_K1 b_b_k0_n_k1_grid_desc,
|
||||
const CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
const Block2CTileMap block_2_ctile_map,
|
||||
const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch)
|
||||
{
|
||||
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
|
||||
const index_t num_blocks_per_batch =
|
||||
__builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
|
||||
const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
|
||||
|
||||
const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
|
||||
const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
|
||||
const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
|
||||
static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
|
||||
|
||||
__shared__ FloatAB p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte() / sizeof(FloatAB)];
|
||||
|
||||
GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
|
||||
p_b_grid + b_batch_offset,
|
||||
p_c_grid + c_batch_offset,
|
||||
p_shared,
|
||||
a_b_k0_m_k1_grid_desc,
|
||||
b_b_k0_n_k1_grid_desc,
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
c_element_op,
|
||||
block_2_ctile_map);
|
||||
#else
|
||||
ignore = p_a_grid;
|
||||
ignore = p_b_grid;
|
||||
ignore = p_c_grid;
|
||||
ignore = a_b_k0_m_k1_grid_desc;
|
||||
ignore = b_b_k0_n_k1_grid_desc;
|
||||
ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
|
||||
ignore = a_element_op;
|
||||
ignore = b_element_op;
|
||||
ignore = c_element_op;
|
||||
ignore = block_2_ctile_map;
|
||||
ignore = compute_ptr_offset_of_batch;
|
||||
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
|
||||
}
|
||||
|
||||
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename InDataType,
|
||||
@@ -57,21 +155,21 @@ template <ck::index_t NDimSpatial,
|
||||
index_t CShuffleNXdlPerWavePerShuffle,
|
||||
typename CBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CBlockTransferScalarPerVector_NWaveNPerXdl>
|
||||
struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
|
||||
: public DeviceConvBwdWeight<
|
||||
struct DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle
|
||||
: public DeviceGroupedConvBwdWeight<
|
||||
NDimSpatial,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::NWC,
|
||||
ck::tensor_layout::convolution::NHWC,
|
||||
ck::tensor_layout::convolution::NDHWC>>,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWC,
|
||||
ck::tensor_layout::convolution::GNHWC,
|
||||
ck::tensor_layout::convolution::GNDHWC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::KXC,
|
||||
ck::tensor_layout::convolution::KYXC,
|
||||
ck::tensor_layout::convolution::KZYXC>>,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GKXC,
|
||||
ck::tensor_layout::convolution::GKYXC,
|
||||
ck::tensor_layout::convolution::GKZYXC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::NWK,
|
||||
ck::tensor_layout::convolution::NHWK,
|
||||
ck::tensor_layout::convolution::NDHWK>>,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWK,
|
||||
ck::tensor_layout::convolution::GNHWK,
|
||||
ck::tensor_layout::convolution::GNDHWK>>,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
@@ -79,7 +177,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle;
|
||||
using DeviceOp = DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle;
|
||||
|
||||
using ADataType = OutDataType;
|
||||
using BDataType = InDataType;
|
||||
@@ -117,18 +215,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
|
||||
static constexpr auto BBlockLdsN1Padding = 4;
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
|
||||
static auto
|
||||
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
ck::index_t batch_k)
|
||||
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
ck::index_t batch_k)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -269,18 +367,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
|
||||
static auto
|
||||
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
ck::index_t batch_k)
|
||||
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
ck::index_t batch_k)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -436,18 +534,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
|
||||
}
|
||||
|
||||
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static auto
|
||||
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
ck::index_t batch_k)
|
||||
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
ck::index_t batch_k)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
@@ -664,8 +762,8 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
|
||||
}
|
||||
|
||||
template <index_t Dim>
|
||||
static auto MakeDescriptor_M0(const std::vector<index_t>& shape,
|
||||
const std::vector<index_t>& stride,
|
||||
static auto MakeDescriptor_M0(const std::array<index_t, Dim>& shape,
|
||||
const std::array<index_t, Dim>& stride,
|
||||
index_t gridSize,
|
||||
index_t blockSize)
|
||||
{
|
||||
@@ -759,16 +857,17 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
|
||||
Argument(const InDataType* p_in_grid,
|
||||
WeiDataType* p_wei_grid,
|
||||
const OutDataType* p_out_grid,
|
||||
ck::index_t G,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
ck::index_t M01,
|
||||
ck::index_t N01,
|
||||
InElementwiseOperation in_element_op,
|
||||
@@ -783,11 +882,13 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
|
||||
c_grid_desc_m_n_{},
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
block_2_ctile_map_{},
|
||||
compute_ptr_offset_of_batch_{},
|
||||
M01_{M01},
|
||||
N01_{N01},
|
||||
a_element_op_{out_element_op},
|
||||
b_element_op_{in_element_op},
|
||||
c_element_op_{wei_element_op},
|
||||
Conv_G_{G},
|
||||
Conv_N_{N},
|
||||
Conv_K_{K},
|
||||
Conv_C_{C},
|
||||
@@ -819,6 +920,26 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
|
||||
block_2_ctile_map_ =
|
||||
GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, k_batch_);
|
||||
|
||||
// A/B/C Batch Stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ =
|
||||
N * K *
|
||||
std::accumulate(begin(output_spatial_lengths),
|
||||
end(output_spatial_lengths),
|
||||
index_t{1},
|
||||
std::multiplies<>{});
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ =
|
||||
N * C *
|
||||
std::accumulate(begin(input_spatial_lengths),
|
||||
end(input_spatial_lengths),
|
||||
index_t{1},
|
||||
std::multiplies<>{});
|
||||
compute_ptr_offset_of_batch_.BatchStrideC_ =
|
||||
K * C *
|
||||
std::accumulate(begin(filter_spatial_lengths),
|
||||
end(filter_spatial_lengths),
|
||||
index_t{1},
|
||||
std::multiplies<>{});
|
||||
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
|
||||
b_grid_desc_kbatch_k0_n_k1_,
|
||||
c_grid_desc_m_n_,
|
||||
@@ -836,21 +957,29 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
|
||||
BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
|
||||
Block2CTileMap block_2_ctile_map_;
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
|
||||
|
||||
index_t M01_;
|
||||
index_t N01_;
|
||||
|
||||
InElementwiseOperation a_element_op_;
|
||||
OutElementwiseOperation b_element_op_;
|
||||
WeiElementwiseOperation c_element_op_;
|
||||
|
||||
// for checking IsSupportedArgument()
|
||||
index_t Conv_G_;
|
||||
index_t Conv_N_;
|
||||
index_t Conv_K_;
|
||||
index_t Conv_C_;
|
||||
std::vector<index_t> output_spatial_lengths_;
|
||||
std::vector<index_t> filter_spatial_lengths_;
|
||||
std::vector<index_t> conv_filter_strides_;
|
||||
std::vector<index_t> input_left_pads_;
|
||||
std::vector<index_t> input_right_pads_;
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths_;
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths_;
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides_;
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads_;
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads_;
|
||||
index_t k_batch_;
|
||||
};
|
||||
|
||||
@@ -873,14 +1002,12 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
|
||||
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", "
|
||||
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
ShowInfo(arg);
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
@@ -891,7 +1018,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
|
||||
}
|
||||
|
||||
const index_t grid_size =
|
||||
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);
|
||||
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Conv_G_;
|
||||
|
||||
const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
|
||||
|
||||
@@ -900,17 +1027,18 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop) {
|
||||
constexpr bool has_main_loop = has_main_k_block_loop.value;
|
||||
|
||||
const auto kernel = kernel_gemm_xdlops_bwd_weight<
|
||||
const auto kernel = kernel_batched_gemm_xdlops_bwd_weight<
|
||||
GridwiseGemm,
|
||||
ADataType, // TODO: distiguish A/B datatype
|
||||
CDataType,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
OutElementwiseOperation,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||
ComputePtrOffsetOfStridedBatch,
|
||||
has_main_loop>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
@@ -921,13 +1049,15 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_c_grid_,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.block_2_ctile_map_);
|
||||
arg.Conv_G_,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_,
|
||||
arg.compute_ptr_offset_of_batch_);
|
||||
};
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
@@ -998,16 +1128,17 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
|
||||
static auto MakeArgument(const InDataType* p_in_grid,
|
||||
WeiDataType* p_wei_grid,
|
||||
const OutDataType* p_out_grid,
|
||||
ck::index_t G,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
@@ -1016,6 +1147,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
|
||||
return Argument{p_in_grid,
|
||||
p_wei_grid,
|
||||
p_out_grid,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
@@ -1040,16 +1172,17 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
|
||||
MakeArgumentPointer(const void* p_in_grid,
|
||||
void* p_wei_grid,
|
||||
const void* p_out_grid,
|
||||
ck::index_t G,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
std::vector<ck::index_t> input_spatial_lengths,
|
||||
std::vector<ck::index_t> filter_spatial_lengths,
|
||||
std::vector<ck::index_t> output_spatial_lengths,
|
||||
std::vector<ck::index_t> conv_filter_strides,
|
||||
std::vector<ck::index_t> conv_filter_dilations,
|
||||
std::vector<ck::index_t> input_left_pads,
|
||||
std::vector<ck::index_t> input_right_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> filter_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> output_spatial_lengths,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides,
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations,
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads,
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
@@ -1058,6 +1191,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
|
||||
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
|
||||
static_cast<WeiDataType*>(p_wei_grid),
|
||||
static_cast<const OutDataType*>(p_out_grid),
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
@@ -1086,7 +1220,7 @@ struct DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceConvNdBwdWeightNwcKxcNwk_Xdl_CShuffle"
|
||||
str << "DeviceGroupedConvBwdWeightGnwcGkxcGnwk_Xdl_CShuffle"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
Reference in New Issue
Block a user