mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Add grouped conv bwd data wmma (#950)
* Add grouped conv bwd data wmma
* Fix copyrights
* Add instances with smaller NPerBlock
* Update interface test
* Minor stylistic fixes
* Minor stylistic fixes
[ROCm/composable_kernel commit: cb53874002]
This commit is contained in:
@@ -1,15 +1,27 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
|
||||
add_custom_target(example_grouped_conv_bwd_data)
|
||||
add_example_executable(example_grouped_conv_bwd_data_fp16 grouped_conv_bwd_data_fp16.cpp)
|
||||
add_example_executable(example_grouped_conv_bwd_data_xdl_fp16 grouped_conv_bwd_data_xdl_fp16.cpp)
|
||||
if(result EQUAL 0)
|
||||
add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_fp16)
|
||||
add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_xdl_fp16)
|
||||
endif()
|
||||
add_example_executable(example_grouped_conv_bwd_data_bias_relu_fp16 grouped_conv_bwd_data_bias_relu_fp16.cpp)
|
||||
add_example_executable(example_grouped_conv_bwd_data_bias_relu_xdl_fp16 grouped_conv_bwd_data_bias_relu_xdl_fp16.cpp)
|
||||
if(result EQUAL 0)
|
||||
add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_fp16)
|
||||
add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_bias_relu_xdl_fp16)
|
||||
endif()
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0)
|
||||
add_custom_target(example_grouped_conv_bwd_data)
|
||||
add_example_executable(example_grouped_conv_bwd_data_wmma_fp16 grouped_conv_bwd_data_wmma_fp16.cpp)
|
||||
if(result EQUAL 0)
|
||||
add_dependencies(example_grouped_conv_bwd_data example_grouped_conv_bwd_data_wmma_fp16)
|
||||
endif()
|
||||
set(target 1)
|
||||
endif()
|
||||
|
||||
@@ -10,7 +10,6 @@
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp"
|
||||
#include "common.hpp"
|
||||
|
||||
using OutDataType = FP16;
|
||||
@@ -0,0 +1,34 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp"
|
||||
#include "common.hpp"
|
||||
|
||||
using OutDataType = FP16;
|
||||
using WeiDataType = FP16;
|
||||
using AccDataType = FP32;
|
||||
using CShuffleDataType = FP16;
|
||||
using DsDataType = ck::Tuple<>;
|
||||
using InDataType = FP16;
|
||||
|
||||
using OutLayout = ck::tensor_layout::convolution::GNHWK;
|
||||
using WeiLayout = ck::tensor_layout::convolution::GKYXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using InLayout = ck::tensor_layout::convolution::GNHWC;
|
||||
|
||||
using OutElementOp = PassThrough;
|
||||
using WeiElementOp = PassThrough;
|
||||
using InElementOp = PassThrough;
|
||||
|
||||
// clang-format off
|
||||
using DeviceConvInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
//| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< 2,OutLayout, WeiLayout, DsLayout, InLayout, OutDataType, WeiDataType, AccDataType, CShuffleDataType, DsDataType, InDataType, OutElementOp, WeiElementOp, InElementOp, ConvBwdDataDefault, 128, 64, 64, 4, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>;
|
||||
// clang-format on
|
||||
|
||||
#include "run_grouped_conv_bwd_data_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return run_grouped_conv_bwd_data_example(argc, argv); }
|
||||
@@ -1,6 +1,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp"
|
||||
#include "common.hpp"
|
||||
|
||||
using OutDataType = FP16;
|
||||
@@ -0,0 +1,879 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_data_multiple_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/io.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
// Conv backward data multiple D:
|
||||
// input : output image A: [G, N, K, Ho, Wo]
|
||||
// input : weight B: [G, K, C, Y, X],
|
||||
// input : D0, D1, ... : [G, N, K, Ho, Wo]
|
||||
// output : input image E: [G, N, C, Hi, Wi]
|
||||
// C = a_op(A) * b_op(B)
|
||||
// E = cde_op(C, D0, D1, ...)
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout, // output image
|
||||
typename BLayout, // weight
|
||||
typename DsLayout, // bias
|
||||
typename ELayout, // input image
|
||||
typename ADataType, // output image
|
||||
typename BDataType, // weight
|
||||
typename AccDataType,
|
||||
typename CShuffleDataType,
|
||||
typename DsDataType, // bias
|
||||
typename EDataType, // input image
|
||||
typename AElementwiseOp, // output image
|
||||
typename BElementwiseOp, // weight
|
||||
typename CDEElementwiseOp, // C, bias, and input image
|
||||
ConvolutionBackwardDataSpecialization ConvBackwardDataSpecialization,
|
||||
ck::index_t BlockSize,
|
||||
ck::index_t MPerBlock,
|
||||
ck::index_t NPerBlock,
|
||||
ck::index_t K0PerBlock,
|
||||
ck::index_t K1,
|
||||
ck::index_t MPerWMMA,
|
||||
ck::index_t NPerWMMA,
|
||||
ck::index_t MRepeat,
|
||||
ck::index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_AK1,
|
||||
bool ABlockLdsExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_BK1,
|
||||
bool BBlockLdsExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
index_t NumGemmKPrefetchStage = 1,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1>
|
||||
struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
: public DeviceGroupedConvBwdDataMultipleD<NDimSpatial,
|
||||
ALayout, // output image
|
||||
BLayout, // weight
|
||||
DsLayout, // bias
|
||||
ELayout, // input image
|
||||
ADataType, // output image
|
||||
BDataType, // weight
|
||||
DsDataType, // bias
|
||||
EDataType, // input image
|
||||
AElementwiseOp,
|
||||
BElementwiseOp,
|
||||
CDEElementwiseOp>
|
||||
{
|
||||
// TODO: Extend support for more spatial dimensions.
|
||||
static_assert(NDimSpatial == 2 || NDimSpatial == 3,
|
||||
"wrong! only implemented for 2D and 3D now");
|
||||
|
||||
using DeviceOp = DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle;
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
// TODO: Add support for different A and B data types.
|
||||
using ABDataType = ADataType;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr index_t KPerBlock = K0PerBlock * K1;
|
||||
|
||||
static constexpr auto transform_conv_to_gemm =
|
||||
TransformConvBwdDataToGemm_v1<NDimSpatial,
|
||||
ConvBackwardDataSpecialization,
|
||||
K1,
|
||||
K1,
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
KPerBlock,
|
||||
true /* DoPadGemmM */,
|
||||
true /* DoPadGemmN */>{};
|
||||
|
||||
static auto GetDummyABDsEGridDescriptor()
|
||||
{
|
||||
const std::array<index_t, NDimSpatial + 3> dummy_tensor_lengths = {1};
|
||||
const std::array<index_t, NDimSpatial + 3> dummy_tensor_strides = {1};
|
||||
const std::array<index_t, NDimSpatial> dummy_spatial_lengths = {1};
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths);
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths);
|
||||
|
||||
const auto ds_grid_desc_m_n = generate_tuple(
|
||||
[&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
return transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths);
|
||||
},
|
||||
Number<NumDTensor>{});
|
||||
|
||||
const auto e_grid_desc_m_n =
|
||||
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_tensor_lengths,
|
||||
dummy_tensor_strides,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths,
|
||||
dummy_spatial_lengths);
|
||||
|
||||
return make_tuple(
|
||||
a_grid_desc_ak0_m_ak1, b_grid_desc_bk0_n_bk1, ds_grid_desc_m_n, e_grid_desc_m_n);
|
||||
}
|
||||
|
||||
// desc
|
||||
using ABDsEGridDesc = decltype(GetDummyABDsEGridDescriptor());
|
||||
|
||||
using AGridDesc_AK0_M_AK1 = remove_cvref_t<tuple_element_t<0, ABDsEGridDesc>>;
|
||||
using BGridDesc_BK0_N_BK1 = remove_cvref_t<tuple_element_t<1, ABDsEGridDesc>>;
|
||||
using DsGridDesc_M_N = remove_cvref_t<tuple_element_t<2, ABDsEGridDesc>>;
|
||||
using EGridDesc_M_N = remove_cvref_t<tuple_element_t<3, ABDsEGridDesc>>;
|
||||
|
||||
// GridwiseGemm
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
|
||||
// DataType Family
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CShuffleDataType,
|
||||
DsDataType,
|
||||
EDataType,
|
||||
// InMemory Data Descriptor
|
||||
AGridDesc_AK0_M_AK1,
|
||||
BGridDesc_BK0_N_BK1,
|
||||
DsGridDesc_M_N,
|
||||
EGridDesc_M_N,
|
||||
// ElementwiseOp Family
|
||||
AElementwiseOp,
|
||||
BElementwiseOp,
|
||||
CDEElementwiseOp,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
// Tiling Family
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
MPerWMMA,
|
||||
NPerWMMA,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
// ThreadCluster Family
|
||||
BlockSize,
|
||||
ABlockTransferThreadClusterLengths_AK0_M_AK1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_AK1,
|
||||
false,
|
||||
ABlockLdsExtraM,
|
||||
BBlockTransferThreadClusterLengths_BK0_N_BK1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_BK1,
|
||||
false,
|
||||
BBlockLdsExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CDEShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CDEShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
NumGemmKPrefetchStage,
|
||||
LoopSched,
|
||||
PipelineVer>;
|
||||
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
DsGridDesc_M_N{}));
|
||||
using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
EGridDesc_M_N{}));
|
||||
|
||||
// Argument
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const void* p_a, // output image
|
||||
const void* p_b, // weight
|
||||
const std::array<const void*, NumDTensor>& p_ds, // bias
|
||||
void* p_e, // input image
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_c_wis_lengths,
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOp& a_element_op,
|
||||
const BElementwiseOp& b_element_op,
|
||||
const CDEElementwiseOp& cde_element_op)
|
||||
: p_a_grid_{static_cast<const ADataType*>(p_a)},
|
||||
p_b_grid_{static_cast<const BDataType*>(p_b)},
|
||||
p_ds_grid_{},
|
||||
p_e_grid_{static_cast<EDataType*>(p_e)},
|
||||
num_group_{a_g_n_k_wos_lengths[0]},
|
||||
a_element_op_{a_element_op},
|
||||
b_element_op_{b_element_op},
|
||||
cde_element_op_{cde_element_op},
|
||||
a_g_n_k_wos_lengths_{a_g_n_k_wos_lengths},
|
||||
a_g_n_k_wos_strides_{a_g_n_k_wos_strides},
|
||||
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
|
||||
b_g_k_c_xs_strides_{b_g_k_c_xs_strides},
|
||||
ds_g_n_c_wis_lengths_{ds_g_n_c_wis_lengths},
|
||||
ds_g_n_c_wis_strides_{ds_g_n_c_wis_strides},
|
||||
e_g_n_c_wis_lengths_{e_g_n_c_wis_lengths},
|
||||
e_g_n_c_wis_strides_{e_g_n_c_wis_strides},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
conv_filter_dilations_{conv_filter_dilations},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
{
|
||||
// populate Ds pointer
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
|
||||
|
||||
p_ds_grid_(i) = static_cast<const DDataType*>(p_ds[i]);
|
||||
});
|
||||
|
||||
// A/B/Ds/E Batch Stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = a_g_n_k_wos_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = b_g_k_c_xs_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideE_ = e_g_n_c_wis_strides[0];
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
compute_ptr_offset_of_batch_.BatchStrideDs_(i) = ds_g_n_c_wis_strides[i][0];
|
||||
});
|
||||
|
||||
static constexpr auto NonSpatialDimsNum = Number<3>{};
|
||||
|
||||
static constexpr auto DIdx = Number<NonSpatialDimsNum>{};
|
||||
static constexpr auto HIdx =
|
||||
NDimSpatial == 2 ? Number<NonSpatialDimsNum>{} : Number<NonSpatialDimsNum + 1>{};
|
||||
static constexpr auto WIdx = NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{}
|
||||
: Number<NonSpatialDimsNum + 2>{};
|
||||
|
||||
static constexpr auto ZIdx = Number<NonSpatialDimsNum>{};
|
||||
static constexpr auto YIdx =
|
||||
NDimSpatial == 2 ? Number<NonSpatialDimsNum>{} : Number<NonSpatialDimsNum + 1>{};
|
||||
static constexpr auto XIdx = NDimSpatial == 2 ? Number<NonSpatialDimsNum + 1>{}
|
||||
: Number<NonSpatialDimsNum + 2>{};
|
||||
|
||||
// problem definition
|
||||
const index_t Z = b_g_k_c_xs_lengths[ZIdx];
|
||||
const index_t Y = b_g_k_c_xs_lengths[YIdx];
|
||||
const index_t X = b_g_k_c_xs_lengths[XIdx];
|
||||
|
||||
const index_t ConvStrideD = conv_filter_strides[DIdx - NonSpatialDimsNum];
|
||||
const index_t ConvStrideH = conv_filter_strides[HIdx - NonSpatialDimsNum];
|
||||
const index_t ConvStrideW = conv_filter_strides[WIdx - NonSpatialDimsNum];
|
||||
|
||||
const index_t ConvDilationD = conv_filter_dilations[DIdx - NonSpatialDimsNum];
|
||||
const index_t ConvDilationH = conv_filter_dilations[HIdx - NonSpatialDimsNum];
|
||||
const index_t ConvDilationW = conv_filter_dilations[WIdx - NonSpatialDimsNum];
|
||||
|
||||
const auto GcdStrideDilationD = math::gcd(ConvStrideD, ConvDilationD);
|
||||
const auto GcdStrideDilationH = math::gcd(ConvStrideH, ConvDilationH);
|
||||
const auto GcdStrideDilationW = math::gcd(ConvStrideW, ConvDilationW);
|
||||
|
||||
const auto ZTilde = NDimSpatial == 3 ? ConvStrideD / GcdStrideDilationD : 1;
|
||||
const auto YTilde = ConvStrideH / GcdStrideDilationH;
|
||||
const auto XTilde = ConvStrideW / GcdStrideDilationW;
|
||||
|
||||
for(index_t i_ztilde = 0; i_ztilde < ZTilde; ++i_ztilde)
|
||||
{
|
||||
|
||||
for(index_t i_ytilde = 0; i_ytilde < YTilde; ++i_ytilde)
|
||||
{
|
||||
for(index_t i_xtilde = 0; i_xtilde < XTilde; ++i_xtilde)
|
||||
{
|
||||
// check slice is valid
|
||||
const auto ZDotSlice =
|
||||
NDimSpatial == 3 ? math::integer_divide_ceil(Z - i_ztilde, ZTilde) : 1;
|
||||
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
|
||||
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
|
||||
|
||||
if(YDotSlice * XDotSlice * ZDotSlice <= 0)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
std::array<index_t, NDimSpatial> tildes;
|
||||
if constexpr(NDimSpatial == 2)
|
||||
{
|
||||
tildes = {i_ytilde, i_xtilde};
|
||||
}
|
||||
else if constexpr(NDimSpatial == 3)
|
||||
{
|
||||
tildes = {i_ztilde, i_ytilde, i_xtilde};
|
||||
}
|
||||
else
|
||||
{
|
||||
throw std::runtime_error("wrong! only implemented for 2D and 3D now");
|
||||
}
|
||||
|
||||
const auto a_grid_desc_ak0_m_ak1 =
|
||||
transform_conv_to_gemm.template MakeADescriptor_AK0_M_AK1<ALayout>(
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes);
|
||||
|
||||
const auto b_grid_desc_bk0_n_bk1 =
|
||||
transform_conv_to_gemm.template MakeBDescriptor_BK0_N_BK1<BLayout>(
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes);
|
||||
|
||||
DsGridDesc_M_N ds_grid_desc_m_n;
|
||||
|
||||
// populate Ds desc
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
ds_grid_desc_m_n(i) =
|
||||
transform_conv_to_gemm.template MakeCDescriptor_M_N<DLayout>(
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
ds_g_n_c_wis_lengths[i],
|
||||
ds_g_n_c_wis_strides[i],
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes);
|
||||
});
|
||||
|
||||
const auto e_grid_desc_m_n =
|
||||
transform_conv_to_gemm.template MakeCDescriptor_M_N<ELayout>(
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
tildes);
|
||||
|
||||
// for check validity
|
||||
ds_grid_desc_m_n_container_.push_back(ds_grid_desc_m_n);
|
||||
e_grid_desc_m_n_container_.push_back(e_grid_desc_m_n);
|
||||
|
||||
// desc for blockwise copy
|
||||
a_grid_desc_ak0_m_ak1_container_.push_back(a_grid_desc_ak0_m_ak1);
|
||||
b_grid_desc_bk0_n_bk1_container_.push_back(b_grid_desc_bk0_n_bk1);
|
||||
|
||||
// block-to-e-tile-map
|
||||
auto block_2_ctile_map = GridwiseGemm::MakeDefaultBlock2CTileMap(
|
||||
e_grid_desc_m_n, 1 /* M01 */, 1 /* N01 */);
|
||||
|
||||
block_2_ctile_map_container_.push_back(block_2_ctile_map);
|
||||
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back(
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n));
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_container_.push_back(
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Print() const
|
||||
{
|
||||
for(std::size_t i = 0; i < a_grid_desc_ak0_m_ak1_container_.size(); i++)
|
||||
{
|
||||
std::cout << "a_grid_desc_ak0_m_ak1_container_"
|
||||
<< a_grid_desc_ak0_m_ak1_container_[i] << std::endl;
|
||||
|
||||
std::cout << "b_grid_desc_bk0_n_bk1_container_"
|
||||
<< b_grid_desc_bk0_n_bk1_container_[i] << std::endl;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto j) {
|
||||
std::cout << "ds_grid_desc_mblock_mperblock_nblock_nperblock_container_"
|
||||
<< ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i][j]
|
||||
<< std::endl;
|
||||
});
|
||||
|
||||
std::cout << "e_grid_desc_mblock_mperblock_nblock_nperblock_container_"
|
||||
<< e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i]
|
||||
<< std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
// pointers
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
typename GridwiseGemm::DsGridPointer p_ds_grid_;
|
||||
EDataType* p_e_grid_;
|
||||
|
||||
// tensor descriptor for problem definition
|
||||
index_t num_group_;
|
||||
std::vector<DsGridDesc_M_N> ds_grid_desc_m_n_container_;
|
||||
std::vector<EGridDesc_M_N> e_grid_desc_m_n_container_;
|
||||
|
||||
// tensor descriptor for block-wise copy
|
||||
std::vector<AGridDesc_AK0_M_AK1> a_grid_desc_ak0_m_ak1_container_;
|
||||
std::vector<BGridDesc_BK0_N_BK1> b_grid_desc_bk0_n_bk1_container_;
|
||||
std::vector<DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
|
||||
ds_grid_desc_mblock_mperblock_nblock_nperblock_container_;
|
||||
std::vector<EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>
|
||||
e_grid_desc_mblock_mperblock_nblock_nperblock_container_;
|
||||
|
||||
// block-to-e-tile map
|
||||
std::vector<typename GridwiseGemm::DefaultBlock2CTileMap> block_2_ctile_map_container_;
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch<NumDTensor> compute_ptr_offset_of_batch_;
|
||||
|
||||
// element-wise op
|
||||
AElementwiseOp a_element_op_;
|
||||
BElementwiseOp b_element_op_;
|
||||
CDEElementwiseOp cde_element_op_;
|
||||
|
||||
// for checking IsSupportedArgument()
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> a_g_n_k_wos_strides_;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> b_g_k_c_xs_strides_;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_c_wis_lengths_;
|
||||
std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor> ds_g_n_c_wis_strides_;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_lengths_;
|
||||
std::array<index_t, NDimSpatial + 3> e_g_n_c_wis_strides_;
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_;
|
||||
std::array<index_t, NDimSpatial> conv_filter_dilations_;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
arg.Print();
|
||||
}
|
||||
|
||||
float ave_time = 0;
|
||||
|
||||
for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
|
||||
{
|
||||
const index_t grid_size = arg.block_2_ctile_map_container_[i].CalculateGridSize(
|
||||
arg.e_grid_desc_m_n_container_[i]) *
|
||||
arg.num_group_;
|
||||
|
||||
const auto GemmK = arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I0) *
|
||||
arg.a_grid_desc_ak0_m_ak1_container_[i].GetLength(I2);
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop) {
|
||||
constexpr bool has_main_loop = has_main_k_block_loop.value;
|
||||
|
||||
const auto kernel = kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle<
|
||||
GridwiseGemm,
|
||||
ADataType,
|
||||
BDataType,
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
EDataType,
|
||||
AElementwiseOp,
|
||||
BElementwiseOp,
|
||||
CDEElementwiseOp,
|
||||
DeviceOp::AGridDesc_AK0_M_AK1,
|
||||
DeviceOp::BGridDesc_BK0_N_BK1,
|
||||
typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
|
||||
ComputePtrOffsetOfStridedBatch<NumDTensor>,
|
||||
has_main_loop>;
|
||||
|
||||
return launch_and_time_kernel(
|
||||
stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
arg.p_ds_grid_,
|
||||
arg.p_e_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.cde_element_op_,
|
||||
arg.a_g_n_k_wos_lengths_[0], // Group count
|
||||
arg.a_grid_desc_ak0_m_ak1_container_[i],
|
||||
arg.b_grid_desc_bk0_n_bk1_container_[i],
|
||||
arg.ds_grid_desc_mblock_mperblock_nblock_nperblock_container_[i],
|
||||
arg.e_grid_desc_mblock_mperblock_nblock_nperblock_container_[i],
|
||||
arg.block_2_ctile_map_container_[i],
|
||||
arg.compute_ptr_offset_of_batch_);
|
||||
};
|
||||
|
||||
if(GridwiseGemm::CalculateHasMainKBlockLoop(GemmK))
|
||||
{
|
||||
ave_time += launch_kernel(integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
ave_time += launch_kernel(integral_constant<bool, false>{});
|
||||
}
|
||||
}
|
||||
|
||||
return ave_time;
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// check device
|
||||
if(get_device_name() == "gfx1100" || get_device_name() == "gfx1101" ||
|
||||
ck::get_device_name() == "gfx1102")
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
const index_t ConvK = arg.b_g_k_c_xs_lengths_[1];
|
||||
const index_t ConvC = arg.b_g_k_c_xs_lengths_[2];
|
||||
|
||||
// Specialization
|
||||
if constexpr(ConvBackwardDataSpecialization ==
|
||||
ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// check if it's a 1x1 convolution with stride=1 and no padding
|
||||
for(int i = 0; i < NDimSpatial; i++)
|
||||
{
|
||||
if(!(arg.b_g_k_c_xs_lengths_[3 + i] == 1 && arg.conv_filter_strides_[i] == 1 &&
|
||||
arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// vector load for A matrix from global memory to LDS
|
||||
if constexpr(is_same_v<ALayout, tensor_layout::convolution::GNHWK> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::GNDHWK> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::NHWGK> ||
|
||||
is_same_v<ALayout, tensor_layout::convolution::NDHWGK>)
|
||||
{
|
||||
if(!(ABlockTransferSrcVectorDim == 2 && ConvK % ABlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// vector load for B matrix from global memory to LDS
|
||||
if constexpr(is_same_v<BLayout, tensor_layout::convolution::GKYXC> ||
|
||||
is_same_v<BLayout, tensor_layout::convolution::GKZYXC>)
|
||||
{
|
||||
if(!(BBlockTransferSrcVectorDim == 1 && ConvC % BBlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// vector store for Ds
|
||||
bool ds_valid = true;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
|
||||
|
||||
if constexpr(is_same_v<DLayout, tensor_layout::convolution::GNHWC> ||
|
||||
is_same_v<DLayout, tensor_layout::convolution::GNDHWC> ||
|
||||
is_same_v<DLayout, tensor_layout::convolution::NHWGC> ||
|
||||
is_same_v<DLayout, tensor_layout::convolution::NDHWGC> ||
|
||||
is_same_v<DLayout, tensor_layout::convolution::G_NHW_C> ||
|
||||
is_same_v<DLayout, tensor_layout::convolution::GC> ||
|
||||
is_same_v<DLayout, tensor_layout::convolution::G_C>)
|
||||
{
|
||||
// vector load D matrix from global memory
|
||||
if(!(ConvC % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
ds_valid = false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
ds_valid = false;
|
||||
}
|
||||
});
|
||||
|
||||
if(!ds_valid)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// vector store for E
|
||||
if constexpr(is_same_v<ELayout, tensor_layout::convolution::GNHWC> ||
|
||||
is_same_v<ELayout, tensor_layout::convolution::GNDHWC> ||
|
||||
is_same_v<ELayout, tensor_layout::convolution::NHWGC> ||
|
||||
is_same_v<ELayout, tensor_layout::convolution::NDHWGC>)
|
||||
{
|
||||
// vector store C matrix into global memory
|
||||
if(!(ConvC % CDEShuffleBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Gridwise GEMM size
|
||||
for(std::size_t i = 0; i < arg.a_grid_desc_ak0_m_ak1_container_.size(); i++)
|
||||
{
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_ak0_m_ak1_container_[i],
|
||||
arg.b_grid_desc_bk0_n_bk1_container_[i],
|
||||
arg.ds_grid_desc_m_n_container_[i],
|
||||
arg.e_grid_desc_m_n_container_[i],
|
||||
arg.block_2_ctile_map_container_[i]))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto
|
||||
MakeArgument(const void* p_a, // output image
|
||||
const void* p_b, // weight
|
||||
const std::array<const void*, NumDTensor>& p_ds, // bias
|
||||
void* p_e, // input image
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output image
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides, // output image
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, // weight
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_c_wis_lengths, // bias
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_c_wis_strides, // bias
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths, // input image
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_strides, // input image
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOp& a_element_op,
|
||||
const BElementwiseOp& b_element_op,
|
||||
const CDEElementwiseOp& cde_element_op)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
ds_g_n_c_wis_lengths,
|
||||
ds_g_n_c_wis_strides,
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
std::unique_ptr<BaseArgument> MakeArgumentPointer(
|
||||
const void* p_a, // output image
|
||||
const void* p_b, // weight
|
||||
const std::array<const void*, NumDTensor>& p_ds, // bias
|
||||
void* p_e, // input image
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_lengths, // output image
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_k_wos_strides, // output image
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides, // weight
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_c_wis_lengths, // bias
|
||||
const std::array<std::array<index_t, NDimSpatial + 3>, NumDTensor>&
|
||||
ds_g_n_c_wis_strides, // bias
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_lengths, // input image
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_c_wis_strides, // input image
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOp& a_element_op,
|
||||
const BElementwiseOp& b_element_op,
|
||||
const CDEElementwiseOp& cde_element_op) override
|
||||
{
|
||||
return std::make_unique<Argument>(p_a,
|
||||
p_b,
|
||||
p_ds,
|
||||
p_e,
|
||||
a_g_n_k_wos_lengths,
|
||||
a_g_n_k_wos_strides,
|
||||
b_g_k_c_xs_lengths,
|
||||
b_g_k_c_xs_strides,
|
||||
ds_g_n_c_wis_lengths,
|
||||
ds_g_n_c_wis_strides,
|
||||
e_g_n_c_wis_lengths,
|
||||
e_g_n_c_wis_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< KPerBlock << ", "
|
||||
<< getConvBackwardDataSpecializationString(ConvBackwardDataSpecialization) << ", "
|
||||
<< K1 << ", "
|
||||
<< ABlockTransferSrcScalarPerVector << ", "
|
||||
<< BBlockTransferSrcScalarPerVector
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_data_to_gemm_v1.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/io.hpp"
|
||||
@@ -24,51 +25,6 @@ namespace device {
|
||||
|
||||
namespace {
|
||||
|
||||
template <index_t NumDTensor>
|
||||
struct ComputePtrOffsetOfStridedBatch
|
||||
{
|
||||
ComputePtrOffsetOfStridedBatch() = default;
|
||||
|
||||
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
Array<ck::index_t, NumDTensor> BatchStrideDs,
|
||||
index_t BatchStrideE)
|
||||
: BatchStrideA_(BatchStrideA),
|
||||
BatchStrideB_(BatchStrideB),
|
||||
BatchStrideDs_(BatchStrideDs),
|
||||
BatchStrideE_(BatchStrideE)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideA_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
|
||||
{
|
||||
Array<long_index_t, NumDTensor> ds_offset;
|
||||
static_for<0, NumDTensor, 1>{}(
|
||||
[&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); });
|
||||
return ds_offset;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideE_);
|
||||
}
|
||||
|
||||
index_t BatchStrideA_;
|
||||
index_t BatchStrideB_;
|
||||
Array<ck::index_t, NumDTensor> BatchStrideDs_;
|
||||
index_t BatchStrideE_;
|
||||
};
|
||||
|
||||
/*
|
||||
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
|
||||
*
|
||||
@@ -257,7 +213,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
BElementwiseOp,
|
||||
CDEElementwiseOp>
|
||||
{
|
||||
// FIXME
|
||||
// TODO: Extend support for more spatial dimensions.
|
||||
static_assert(NDimSpatial == 2 || NDimSpatial == 3,
|
||||
"wrong! only implemented for 2D and 3D now");
|
||||
|
||||
@@ -265,7 +221,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
|
||||
|
||||
static constexpr index_t NumDTensor = DsDataType::Size();
|
||||
|
||||
// TODO make A/B datatype different
|
||||
// TODO: Add support for different A and B data types.
|
||||
using ABDataType = ADataType;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_multiple_d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/io.hpp"
|
||||
@@ -29,51 +30,6 @@ namespace device {
|
||||
|
||||
namespace {
|
||||
|
||||
template <index_t NumDTensor>
|
||||
struct ComputePtrOffsetOfStridedBatch
|
||||
{
|
||||
ComputePtrOffsetOfStridedBatch() = default;
|
||||
|
||||
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
Array<ck::index_t, NumDTensor> BatchStrideDs,
|
||||
index_t BatchStrideE)
|
||||
: BatchStrideA_(BatchStrideA),
|
||||
BatchStrideB_(BatchStrideB),
|
||||
BatchStrideDs_(BatchStrideDs),
|
||||
BatchStrideE_(BatchStrideE)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideA_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
|
||||
{
|
||||
Array<long_index_t, NumDTensor> ds_offset;
|
||||
static_for<0, NumDTensor, 1>{}(
|
||||
[&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); });
|
||||
return ds_offset;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideE_);
|
||||
}
|
||||
|
||||
index_t BatchStrideA_;
|
||||
index_t BatchStrideB_;
|
||||
Array<ck::index_t, NumDTensor> BatchStrideDs_;
|
||||
index_t BatchStrideE_;
|
||||
};
|
||||
|
||||
/*
|
||||
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
|
||||
*
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/io.hpp"
|
||||
@@ -27,55 +28,6 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
namespace {
|
||||
|
||||
template <index_t NumDTensor>
|
||||
struct ComputePtrOffsetOfStridedBatch
|
||||
{
|
||||
ComputePtrOffsetOfStridedBatch() = default;
|
||||
|
||||
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
Array<ck::index_t, NumDTensor> BatchStrideDs,
|
||||
index_t BatchStrideE)
|
||||
: BatchStrideA_(BatchStrideA),
|
||||
BatchStrideB_(BatchStrideB),
|
||||
BatchStrideDs_(BatchStrideDs),
|
||||
BatchStrideE_(BatchStrideE)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideA_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
|
||||
{
|
||||
Array<long_index_t, NumDTensor> ds_offset;
|
||||
static_for<0, NumDTensor, 1>{}(
|
||||
[&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); });
|
||||
return ds_offset;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideE_);
|
||||
}
|
||||
|
||||
index_t BatchStrideA_;
|
||||
index_t BatchStrideB_;
|
||||
Array<ck::index_t, NumDTensor> BatchStrideDs_;
|
||||
index_t BatchStrideE_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
//
|
||||
// @brief Device Convolution operation.
|
||||
//
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
#include "ck/host_utility/io.hpp"
|
||||
@@ -29,51 +30,6 @@ namespace device {
|
||||
|
||||
namespace {
|
||||
|
||||
template <index_t NumDTensor>
|
||||
struct ComputePtrOffsetOfStridedBatch
|
||||
{
|
||||
ComputePtrOffsetOfStridedBatch() = default;
|
||||
|
||||
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
Array<ck::index_t, NumDTensor> BatchStrideDs,
|
||||
index_t BatchStrideE)
|
||||
: BatchStrideA_(BatchStrideA),
|
||||
BatchStrideB_(BatchStrideB),
|
||||
BatchStrideDs_(BatchStrideDs),
|
||||
BatchStrideE_(BatchStrideE)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideA_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
|
||||
{
|
||||
Array<long_index_t, NumDTensor> ds_offset;
|
||||
static_for<0, NumDTensor, 1>{}(
|
||||
[&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); });
|
||||
return ds_offset;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideE_);
|
||||
}
|
||||
|
||||
index_t BatchStrideA_;
|
||||
index_t BatchStrideB_;
|
||||
Array<ck::index_t, NumDTensor> BatchStrideDs_;
|
||||
index_t BatchStrideE_;
|
||||
};
|
||||
|
||||
/*
|
||||
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
|
||||
*
|
||||
|
||||
@@ -0,0 +1,59 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <index_t NumDTensor>
|
||||
struct ComputePtrOffsetOfStridedBatch
|
||||
{
|
||||
ComputePtrOffsetOfStridedBatch() = default;
|
||||
|
||||
ComputePtrOffsetOfStridedBatch(index_t BatchStrideA,
|
||||
index_t BatchStrideB,
|
||||
Array<ck::index_t, NumDTensor> BatchStrideDs,
|
||||
index_t BatchStrideE)
|
||||
: BatchStrideA_(BatchStrideA),
|
||||
BatchStrideB_(BatchStrideB),
|
||||
BatchStrideDs_(BatchStrideDs),
|
||||
BatchStrideE_(BatchStrideE)
|
||||
{
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideA_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto GetDsPtrOffset(index_t g_idx) const
|
||||
{
|
||||
Array<long_index_t, NumDTensor> ds_offset;
|
||||
static_for<0, NumDTensor, 1>{}(
|
||||
[&](auto i) { ds_offset(i) = g_idx * static_cast<long_index_t>(BatchStrideDs_[i]); });
|
||||
return ds_offset;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideE_);
|
||||
}
|
||||
|
||||
index_t BatchStrideA_;
|
||||
index_t BatchStrideB_;
|
||||
Array<ck::index_t, NumDTensor> BatchStrideDs_;
|
||||
index_t BatchStrideE_;
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,118 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
|
||||
using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvBwdDataDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvBwdData1x1S1P0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
template <index_t NDSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename DsDatatype,
|
||||
typename CDEElementOp,
|
||||
ConvolutionBackwardDataSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_data_wmma_f16_instances = std::tuple<
|
||||
// clang-format off
|
||||
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 128, 64, 64, 4, 8, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
|
||||
// blocksize=256
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 256, 128, 256, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 256, 64, 256, 8, 8, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 256, 128, 256, 8, 8, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 256, 128, 64, 8, 8, 16, 16, 4, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
// blocksize=128
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 128, 64, 128, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 128, 64, 128, 8, 8, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 128, 128, 128, 8, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 128, 32, 256, 8, 8, 16, 16, 1, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
// blocksize=64
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 64, 32, 64, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 64, 64, 64, 8, 8, 16, 16, 2, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 64, 32, 64, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 64, 32, 128, 8, 8, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
// blocksize=32
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 32, 16, 64, 8, 8, 16, 16, 1, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 32, 64, 32, 8, 8, 16, 16, 4, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 32, 32, 32, 8, 8, 16, 16, 2, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, 32, 16, 32, 8, 8, 16, 16, 1, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
typename DsDatatype,
|
||||
typename CDEElementOp,
|
||||
ConvolutionBackwardDataSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_data_wmma_i8_instances = std::tuple<
|
||||
// clang-format off
|
||||
//########################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//########################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//########################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, Empty_Tuple, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, 128, 64, 64, 4, 16, 16, 16, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 16, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 16, 1, 1, 1, S<1, 32, 1, 4>, 1>,
|
||||
// blocksize=256
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, Empty_Tuple, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, 256, 64, 256, 8, 16, 16, 16, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
// blocksize=128
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, Empty_Tuple, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, 128, 64, 256, 8, 16, 16, 16, 2, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, Empty_Tuple, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, 128, 64, 128, 8, 16, 16, 16, 2, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, Empty_Tuple, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, 128, 128, 256, 8, 16, 16, 16, 4, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, Empty_Tuple, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, 128, 32, 256, 8, 16, 16, 16, 1, 8, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, Empty_Tuple, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, 128, 256, 128, 8, 16, 16, 16, 8, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
// blocksize=64
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, Empty_Tuple, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, 64, 32, 128, 8, 16, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, Empty_Tuple, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, 64, 64, 128, 8, 16, 16, 16, 2, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, Empty_Tuple, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, 64, 32, 128, 8, 16, 16, 16, 1, 8, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, Empty_Tuple, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, 64, 32, 64, 8, 16, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
// blocksize=32
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, Empty_Tuple, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, 32, 16, 64, 8, 16, 16, 16, 1, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, Empty_Tuple, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, 32, 64, 64, 8, 16, 16, 16, 4, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, Empty_Tuple, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, 32, 32, 32, 8, 16, 16, 16, 2, 2, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<NDSpatial, ALayout, BLayout, DsLayout, ELayout, I8, I8, I32, I8, Empty_Tuple, I8, PassThrough, PassThrough, CDEElementOp, ConvSpec, 32, 16, 64, 8, 16, 16, 16, 1, 4, S<2, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 16, 1, 1, 1, S<1, 16, 1, 2>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -30,6 +30,34 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
|
||||
@@ -61,6 +89,35 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
@@ -75,6 +132,35 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
|
||||
@@ -106,6 +192,35 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
// conv3d backward data
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
|
||||
@@ -121,6 +236,34 @@ void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
|
||||
@@ -152,6 +295,35 @@ void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
@@ -166,6 +338,34 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
|
||||
@@ -197,6 +397,35 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
template <ck::index_t NumDimSpatial,
|
||||
typename OutLayout,
|
||||
typename WeiLayout,
|
||||
@@ -247,6 +476,10 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, F16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
@@ -263,6 +496,15 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
|
||||
@@ -273,6 +515,10 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<OutDataType, F16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
@@ -289,6 +535,15 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
@@ -304,6 +559,10 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
@@ -321,6 +580,16 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
|
||||
@@ -332,6 +601,10 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
@@ -349,6 +622,16 @@ struct DeviceOperationInstanceFactory<
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,16 @@
|
||||
add_instance_library(device_grouped_conv2d_bwd_data_instance
|
||||
device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
|
||||
device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp
|
||||
device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp
|
||||
device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
|
||||
)
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
|
||||
|
||||
wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_i8_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_i8_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_f16_instance.cpp
|
||||
wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
wmma/device_grouped_conv2d_bwd_data_wmma_gnhwc_gkyxc_gnhwk_i8_instance.cpp
|
||||
wmma/device_grouped_conv2d_bwd_data_wmma_nhwgc_gkyxc_nhwgk_i8_instance.cpp)
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_wmma_f16_instances<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvBwdData1x1S1P0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_wmma_f16_instances<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvBwdDataDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_wmma_i8_instances<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvBwdData1x1S1P0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_gnhwk_gkyxc_gnhwc_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_wmma_i8_instances<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvBwdDataDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_wmma_f16_instances<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvBwdData1x1S1P0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_wmma_f16_instances<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvBwdDataDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_wmma_i8_instances<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvBwdData1x1S1P0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv2d_bwd_data_wmma_nhwgk_gkyxc_nhwgc_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_wmma_i8_instances<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvBwdDataDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,8 +1,17 @@
|
||||
add_instance_library(device_grouped_conv3d_bwd_data_instance
|
||||
device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
|
||||
device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp
|
||||
device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
|
||||
device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
|
||||
wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp
|
||||
)
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_wmma_f16_instances<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvBwdData1x1S1P0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_wmma_f16_instances<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvBwdDataDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_wmma_i8_instances<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvBwdData1x1S1P0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_gndhwk_gkzyxc_gndhwc_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_wmma_i8_instances<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvBwdDataDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_wmma_f16_instances<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvBwdData1x1S1P0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_wmma_f16_instances<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvBwdDataDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_wmma_i8_instances<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvBwdData1x1S1P0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_bwd_data_wmma_ndhwgk_gkzyxc_ndhwgc_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
int8_t,
|
||||
int8_t,
|
||||
Empty_Tuple,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_wmma_i8_instances<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
Empty_Tuple,
|
||||
PassThrough,
|
||||
ConvBwdDataDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,6 +1,19 @@
|
||||
if(GPU_TARGETS MATCHES "gfx908" OR GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940")
|
||||
add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data.cpp)
|
||||
target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance)
|
||||
add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface.cpp)
|
||||
target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance)
|
||||
endif()
|
||||
list(APPEND gpu_list_xdl gfx908 gfx90a gfx940)
|
||||
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
|
||||
add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data.cpp)
|
||||
target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance)
|
||||
add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_xdl.cpp)
|
||||
target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance)
|
||||
set(target 1)
|
||||
endif()
|
||||
if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0)
|
||||
add_gtest_executable(test_grouped_convnd_bwd_data test_grouped_convnd_bwd_data.cpp)
|
||||
target_link_libraries(test_grouped_convnd_bwd_data PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance)
|
||||
add_gtest_executable(test_grouped_convnd_bwd_data_interface test_grouped_convnd_bwd_data_interface_wmma.cpp)
|
||||
target_link_libraries(test_grouped_convnd_bwd_data_interface PRIVATE utility device_grouped_conv2d_bwd_data_instance)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
@@ -51,16 +51,20 @@ using namespace ck::tensor_layout::convolution;
|
||||
using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWK, GKYXC, GNHWC>,
|
||||
std::tuple<ck::half_t, GNHWK, GKYXC, GNHWC>,
|
||||
std::tuple<ck::bhalf_t, GNHWK, GKYXC, GNHWC>,
|
||||
std::tuple<int8_t, GNHWK, GKYXC, GNHWC>,
|
||||
std::tuple<float, NHWGK, GKYXC, NHWGC>,
|
||||
std::tuple<ck::half_t, NHWGK, GKYXC, NHWGC>,
|
||||
std::tuple<ck::bhalf_t, NHWGK, GKYXC, NHWGC>>;
|
||||
std::tuple<ck::bhalf_t, NHWGK, GKYXC, NHWGC>,
|
||||
std::tuple<int8_t, NHWGK, GKYXC, NHWGC>>;
|
||||
|
||||
using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWK, GKZYXC, GNDHWC>,
|
||||
std::tuple<ck::half_t, GNDHWK, GKZYXC, GNDHWC>,
|
||||
std::tuple<ck::bhalf_t, GNDHWK, GKZYXC, GNDHWC>,
|
||||
std::tuple<int8_t, GNDHWK, GKZYXC, GNDHWC>,
|
||||
std::tuple<float, NDHWGK, GKZYXC, NDHWGC>,
|
||||
std::tuple<ck::half_t, NDHWGK, GKZYXC, NDHWGC>,
|
||||
std::tuple<ck::bhalf_t, NDHWGK, GKZYXC, NDHWGC>>;
|
||||
std::tuple<ck::bhalf_t, NDHWGK, GKZYXC, NDHWGC>,
|
||||
std::tuple<int8_t, NDHWGK, GKZYXC, NDHWGC>>;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdData2d : public TestGroupedConvndBwdData<Tuple>
|
||||
|
||||
@@ -0,0 +1,178 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <initializer_list>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_wmma_cshuffle.hpp"
|
||||
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
using DataType = ck::half_t;
|
||||
using AccDataType = float;
|
||||
using Pass = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
using ConvBackwardDataSpecialization =
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization;
|
||||
|
||||
static constexpr auto ConvBwdDataDefault = ConvBackwardDataSpecialization::Default;
|
||||
static constexpr auto Filter1x1Stride1Pad0 = ConvBackwardDataSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
template <typename Tuple, ConvBackwardDataSpecialization ConvSpec>
|
||||
class TestGroupedConvndBwdData : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
static constexpr ck::index_t NDimSpatial = 2;
|
||||
|
||||
using OutLayout = std::tuple_element_t<0, Tuple>;
|
||||
using WeiLayout = std::tuple_element_t<1, Tuple>;
|
||||
using InLayout = std::tuple_element_t<2, Tuple>;
|
||||
|
||||
// clang-format off
|
||||
using GroupedConvBwdDataDeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
//| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| DsData| EData| A| B| CDE| ConvForward| Block| MPer| NPer| K0Per| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl|
|
||||
//| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
< NDimSpatial,OutLayout, WeiLayout, ck::Tuple<>, InLayout, DataType, DataType, AccDataType, DataType, ck::Tuple<>, DataType, Pass, Pass, Pass, ConvSpec, 64, 32, 64, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>;
|
||||
// clang-format on
|
||||
|
||||
ck::utils::conv::ConvParam conv_param;
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
bool Run()
|
||||
{
|
||||
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
|
||||
conv_param);
|
||||
|
||||
std::array<ck::index_t, NDimSpatial + 3> out_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> out_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> wei_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> wei_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> in_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> in_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads{};
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads{};
|
||||
|
||||
auto copy = [](const auto& x, auto& y) { ck::ranges::copy(x, y.begin()); };
|
||||
|
||||
copy(out_g_n_k_wos_desc.GetLengths(), out_lengths);
|
||||
copy(out_g_n_k_wos_desc.GetStrides(), out_strides);
|
||||
copy(wei_g_k_c_xs_desc.GetLengths(), wei_lengths);
|
||||
copy(wei_g_k_c_xs_desc.GetStrides(), wei_strides);
|
||||
copy(in_g_n_c_wis_desc.GetLengths(), in_lengths);
|
||||
copy(in_g_n_c_wis_desc.GetStrides(), in_strides);
|
||||
copy(conv_param.conv_filter_strides_, conv_filter_strides);
|
||||
copy(conv_param.conv_filter_dilations_, conv_filter_dilations);
|
||||
copy(conv_param.input_left_pads_, input_left_pads);
|
||||
copy(conv_param.input_right_pads_, input_right_pads);
|
||||
|
||||
auto conv = GroupedConvBwdDataDeviceInstance{};
|
||||
|
||||
auto argument = conv.MakeArgument(nullptr,
|
||||
nullptr,
|
||||
std::array<const void*, 0>{},
|
||||
nullptr,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
in_lengths,
|
||||
in_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
Pass{},
|
||||
Pass{},
|
||||
Pass{});
|
||||
return conv.IsSupportedArgument(argument);
|
||||
}
|
||||
};
|
||||
|
||||
using GNHWC = ck::tensor_layout::convolution::GNHWC;
|
||||
using NHWGC = ck::tensor_layout::convolution::NHWGC;
|
||||
|
||||
using GKYXC = ck::tensor_layout::convolution::GKYXC;
|
||||
|
||||
using GNHWK = ck::tensor_layout::convolution::GNHWK;
|
||||
using NHWGK = ck::tensor_layout::convolution::NHWGK;
|
||||
|
||||
using KernelTypes =
|
||||
::testing::Types<std::tuple<GNHWK, GKYXC, GNHWC>, std::tuple<NHWGK, GKYXC, NHWGC>>;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdDataDefault : public TestGroupedConvndBwdData<Tuple, ConvBwdDataDefault>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdDataFilter1x1
|
||||
: public TestGroupedConvndBwdData<Tuple, Filter1x1Stride1Pad0>
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdDataDefault, KernelTypes);
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdDataFilter1x1, KernelTypes);
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdDataFilter1x1, SpecializationCheck)
|
||||
{
|
||||
// Check filter 3,3 instead of 1,1
|
||||
this->conv_param = {2, 2, 4, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
|
||||
bool is_supported = this->template Run<2>();
|
||||
EXPECT_FALSE(is_supported);
|
||||
|
||||
// Check strides 2,2 instead of 1,1
|
||||
this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {2, 2}, {1, 1}, {0, 0}, {0, 0}};
|
||||
is_supported = this->template Run<2>();
|
||||
EXPECT_FALSE(is_supported);
|
||||
|
||||
// Check with pad
|
||||
this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}};
|
||||
is_supported = this->template Run<2>();
|
||||
EXPECT_FALSE(is_supported);
|
||||
|
||||
// Supported version
|
||||
this->conv_param = {2, 2, 4, 192, 192, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
|
||||
is_supported = this->template Run<2>();
|
||||
EXPECT_TRUE(is_supported);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdDataDefault, VectorLoadCheck)
|
||||
{
|
||||
// vector load for A
|
||||
this->conv_param = {2, 2, 128, 129, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}};
|
||||
bool is_supported = this->template Run<2>();
|
||||
EXPECT_FALSE(is_supported);
|
||||
// vector load for B, E, Ds
|
||||
this->conv_param = {2, 2, 128, 128, 257, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}};
|
||||
is_supported = this->template Run<2>();
|
||||
EXPECT_FALSE(is_supported);
|
||||
}
|
||||
Reference in New Issue
Block a user