mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Add grouped conv bwd weight wmma (#985)
* Add grouped conv bwd weight wmma
* Update README, changelog, profiler
* Minor fixes
* Fix grouped conv bwd wei dl kernel
* Minor fixes
* Minor stylistic fixes
[ROCm/composable_kernel commit: 16d7c4d2f7]
This commit is contained in:
@@ -14,7 +14,7 @@ None
|
||||
### Additions
|
||||
- Added an image to a column kernel (#867)
|
||||
- Added a column to an image kernel (#930)
|
||||
- Support for 3D grouped convolution forward on RDNA 3 GPUs (#935)
|
||||
- Support for 3D grouped convolution on RDNA 3 GPUs (#935, #950, #985)
|
||||
- Grouped convolution support for small K and C (#822 #879 #897)
|
||||
- Support for NHWGC (2D and 3D) grouped convolution backward weight (#769 #804)
|
||||
- Support for bf16/f32/f16 and NHWGC (2D and 3d) grouped convolution backward data (#757 #799)
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
|
||||
add_custom_target(example_grouped_conv_bwd_weight)
|
||||
add_example_executable(example_grouped_conv_bwd_weight_xdl_fp16 grouped_conv_bwd_weight_xdl_fp16.cpp)
|
||||
if(result EQUAL 0)
|
||||
@@ -18,6 +19,14 @@ foreach(gpu IN LISTS GPU_TARGETS)
|
||||
endif()
|
||||
endif()
|
||||
set(target 1)
|
||||
endif()
|
||||
if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0)
|
||||
add_custom_target(example_grouped_conv_bwd_weight)
|
||||
add_example_executable(example_grouped_conv_bwd_weight_wmma_fp16 grouped_conv_bwd_weight_wmma_fp16.cpp)
|
||||
if(result EQUAL 0)
|
||||
add_dependencies(example_grouped_conv_bwd_weight example_grouped_conv_bwd_weight_wmma_fp16)
|
||||
endif()
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
|
||||
@@ -46,25 +46,21 @@ struct CommonLayoutSetting
|
||||
using OutputLayout = OutputLay;
|
||||
};
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
struct CommonLayoutSettingSelector;
|
||||
|
||||
namespace ctl = ck::tensor_layout::convolution;
|
||||
|
||||
template <>
|
||||
struct CommonLayoutSettingSelector<1> final : CommonLayoutSetting<ctl::GNWC, ctl::GKXC, ctl::GNWK>
|
||||
{
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CommonLayoutSettingSelector<2> final
|
||||
: CommonLayoutSetting<ctl::GNHWC, ctl::GKYXC, ctl::GNHWK>
|
||||
{
|
||||
};
|
||||
|
||||
template <>
|
||||
struct CommonLayoutSettingSelector<3> final
|
||||
: CommonLayoutSetting<ctl::GNDHWC, ctl::GKZYXC, ctl::GNDHWK>
|
||||
template <ck::index_t NDimSpatial>
|
||||
struct CommonLayoutSettingSelector
|
||||
: CommonLayoutSetting<ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWC,
|
||||
ck::tensor_layout::convolution::GNHWC,
|
||||
ck::tensor_layout::convolution::GNDHWC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GKXC,
|
||||
ck::tensor_layout::convolution::GKYXC,
|
||||
ck::tensor_layout::convolution::GKZYXC>>,
|
||||
ck::tuple_element_t<NDimSpatial - 1,
|
||||
ck::Tuple<ck::tensor_layout::convolution::GNWK,
|
||||
ck::tensor_layout::convolution::GNHWK,
|
||||
ck::tensor_layout::convolution::GNDHWK>>>
|
||||
{
|
||||
};
|
||||
|
||||
@@ -84,10 +80,10 @@ struct ExecutionConfig final
|
||||
bool time_kernel = false;
|
||||
};
|
||||
|
||||
#define DefaultConvParam \
|
||||
ck::utils::conv::ConvParam \
|
||||
{ \
|
||||
2, 4, 1, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, { 1, 1 } \
|
||||
#define DefaultConvParam \
|
||||
ck::utils::conv::ConvParam \
|
||||
{ \
|
||||
3, 4, 1, 128, 256, {3, 3, 3}, {14, 14, 14}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, { 1, 1, 1 } \
|
||||
}
|
||||
|
||||
inline void print_help_msg()
|
||||
|
||||
@@ -76,4 +76,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
|
||||
|
||||
#include "run_grouped_conv_bwd_weight_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExecutionConfig config;
|
||||
ck::utils::conv::ConvParam conv_param = DefaultConvParam;
|
||||
|
||||
if(!parse_cmd_args(argc, argv, config, conv_param))
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
switch(conv_param.num_dim_spatial_)
|
||||
{
|
||||
case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param);
|
||||
case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param);
|
||||
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
|
||||
default: break;
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
|
||||
|
||||
using InDataType = F16;
|
||||
using WeiDataType = F16;
|
||||
using OutDataType = F16;
|
||||
using AccDataType = F32;
|
||||
|
||||
using InElementOp = PassThrough;
|
||||
using WeiElementOp = PassThrough;
|
||||
using OutElementOp = PassThrough;
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
using DeviceConvBwdWeightInstance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle<
|
||||
NDimSpatial,
|
||||
ck::tensor_layout::convolution::GNDHWC,
|
||||
ck::tensor_layout::convolution::GKZYXC,
|
||||
ck::tensor_layout::convolution::GNDHWK,
|
||||
InDataType, // InDataType
|
||||
WeiDataType, // WeiDataType
|
||||
OutDataType, // OutDataType
|
||||
AccDataType, // AccDataType
|
||||
InElementOp, // InElementwiseOperation
|
||||
WeiElementOp, // WeiElementwiseOperation
|
||||
OutElementOp, // OutElementwiseOperation
|
||||
ConvBwdWeightDefault, // ConvolutionBackwardWeightSpecialization
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
128, // NPerBlock
|
||||
4, // K0PerBlock
|
||||
8, // K1
|
||||
16, // MPerWMMA
|
||||
16, // NPerWMMA
|
||||
4, // MRepeat
|
||||
2, // NRepeat
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
|
||||
S<0, 2, 1>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<0, 2, 1>, // ABlockTransferSrcAccessOrder
|
||||
1, // ABlockTransferSrcVectorDim
|
||||
1, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_AK1
|
||||
true, // ABlockLdsExtraM
|
||||
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
|
||||
S<0, 2, 1>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<0, 2, 1>, // BBlockTransferSrcAccessOrder
|
||||
1, // BBlockTransferSrcVectorDim
|
||||
1, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_BK1
|
||||
true, // BBlockLdsExtraN
|
||||
4,
|
||||
2,
|
||||
S<1, 32, 1, 8>,
|
||||
1>;
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWeight<NDimSpatial,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementOp,
|
||||
WeiElementOp,
|
||||
OutElementOp>;
|
||||
|
||||
#include "run_grouped_conv_bwd_weight_example.inc"
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExecutionConfig config;
|
||||
ck::utils::conv::ConvParam conv_param = DefaultConvParam;
|
||||
|
||||
if(!parse_cmd_args(argc, argv, config, conv_param))
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
switch(conv_param.num_dim_spatial_)
|
||||
{
|
||||
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
|
||||
default: break;
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
@@ -78,4 +78,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
|
||||
|
||||
#include "run_grouped_conv_bwd_weight_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExecutionConfig config;
|
||||
ck::utils::conv::ConvParam conv_param = DefaultConvParam;
|
||||
|
||||
if(!parse_cmd_args(argc, argv, config, conv_param))
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
switch(conv_param.num_dim_spatial_)
|
||||
{
|
||||
case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param);
|
||||
case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param);
|
||||
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
|
||||
default: break;
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -77,4 +77,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
|
||||
|
||||
#include "run_grouped_conv_bwd_weight_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExecutionConfig config;
|
||||
ck::utils::conv::ConvParam conv_param = DefaultConvParam;
|
||||
|
||||
if(!parse_cmd_args(argc, argv, config, conv_param))
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
switch(conv_param.num_dim_spatial_)
|
||||
{
|
||||
case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param);
|
||||
case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param);
|
||||
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
|
||||
default: break;
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -83,4 +83,23 @@ using HostConvBwdWeightInstance = ck::tensor_operation::host::ReferenceConvBwdWe
|
||||
|
||||
#include "run_grouped_conv_bwd_weight_example.inc"
|
||||
|
||||
int main(int argc, char* argv[]) { return !run_grouped_conv_bwd_weight_example(argc, argv); }
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
ExecutionConfig config;
|
||||
ck::utils::conv::ConvParam conv_param = DefaultConvParam;
|
||||
|
||||
if(!parse_cmd_args(argc, argv, config, conv_param))
|
||||
{
|
||||
return 1;
|
||||
}
|
||||
|
||||
switch(conv_param.num_dim_spatial_)
|
||||
{
|
||||
case 1: return !run_grouped_conv_bwd_weight<1>(config, conv_param);
|
||||
case 2: return !run_grouped_conv_bwd_weight<2>(config, conv_param);
|
||||
case 3: return !run_grouped_conv_bwd_weight<3>(config, conv_param);
|
||||
default: break;
|
||||
}
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ template <ck::index_t NDimSpatial>
|
||||
bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
|
||||
const ck::utils::conv::ConvParam& conv_param)
|
||||
{
|
||||
// Dl op doesn't support split_k > 1
|
||||
// Dl and WMMA ops don't support split_k > 1
|
||||
constexpr ck::index_t split_k = 1;
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
@@ -143,23 +143,3 @@ bool run_grouped_conv_bwd_weight(const ExecutionConfig& config,
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
bool run_grouped_conv_bwd_weight_example(int argc, char* argv[])
|
||||
{
|
||||
ExecutionConfig config;
|
||||
ck::utils::conv::ConvParam conv_param = DefaultConvParam;
|
||||
|
||||
if(!parse_cmd_args(argc, argv, config, conv_param))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
switch(conv_param.num_dim_spatial_)
|
||||
{
|
||||
case 1: return run_grouped_conv_bwd_weight<1>(config, conv_param);
|
||||
case 2: return run_grouped_conv_bwd_weight<2>(config, conv_param);
|
||||
case 3: return run_grouped_conv_bwd_weight<3>(config, conv_param);
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -565,7 +565,7 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop) {
|
||||
constexpr bool has_main_loop = has_main_k_block_loop.value;
|
||||
|
||||
const auto kernel = kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle<
|
||||
const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle<
|
||||
GridwiseGemm,
|
||||
ADataType,
|
||||
BDataType,
|
||||
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_dl_v1r3.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
|
||||
@@ -22,32 +23,6 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
namespace {
|
||||
|
||||
struct ComputePtrOffsetOfStridedBatch
|
||||
{
|
||||
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideA_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideC_);
|
||||
}
|
||||
|
||||
index_t BatchStrideA_;
|
||||
index_t BatchStrideB_;
|
||||
index_t BatchStrideC_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatAB,
|
||||
typename FloatC,
|
||||
@@ -952,7 +927,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
|
||||
Block2CTileMap block_2_ctile_map_;
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
|
||||
ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
|
||||
|
||||
// element-wise op
|
||||
OutElementwiseOperation a_element_op_;
|
||||
@@ -1024,7 +999,7 @@ struct DeviceGroupedConvBwdWeight_Dl : public DeviceGroupedConvBwdWeight<NDimSpa
|
||||
remove_reference_t<DeviceOp::BGridDesc_B_K0_N0_N1_K1>,
|
||||
remove_reference_t<DeviceOp::CGridDesc_M0_M10_M11_N0_N10_N11>,
|
||||
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||
ComputePtrOffsetOfStridedBatch,
|
||||
ComputePtrOffsetOfStridedBatch<I0>,
|
||||
has_main_loop,
|
||||
has_double_loop>;
|
||||
|
||||
|
||||
@@ -0,0 +1,877 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <iostream>
|
||||
#include <numeric>
|
||||
#include <sstream>
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor.hpp"
|
||||
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename InLayout,
|
||||
typename WeiLayout,
|
||||
typename OutLayout,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType,
|
||||
typename InElementwiseOperation,
|
||||
typename WeiElementwiseOperation,
|
||||
typename OutElementwiseOperation,
|
||||
ConvolutionBackwardWeightSpecialization ConvBackwardWeightSpecialization,
|
||||
index_t BlockSize,
|
||||
index_t MPerBlock,
|
||||
index_t NPerBlock,
|
||||
index_t K0PerBlock,
|
||||
index_t K1,
|
||||
index_t MPerWMMA,
|
||||
index_t NPerWMMA,
|
||||
index_t MRepeat,
|
||||
index_t NRepeat,
|
||||
typename ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
typename ABlockTransferThreadClusterArrangeOrder,
|
||||
typename ABlockTransferSrcAccessOrder,
|
||||
index_t ABlockTransferSrcVectorDim,
|
||||
index_t ABlockTransferSrcScalarPerVector,
|
||||
index_t ABlockTransferDstScalarPerVector_K1,
|
||||
bool ABlockLdsAddExtraM,
|
||||
typename BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
typename BBlockTransferThreadClusterArrangeOrder,
|
||||
typename BBlockTransferSrcAccessOrder,
|
||||
index_t BBlockTransferSrcVectorDim,
|
||||
index_t BBlockTransferSrcScalarPerVector,
|
||||
index_t BBlockTransferDstScalarPerVector_K1,
|
||||
bool BBlockLdsAddExtraN,
|
||||
index_t CShuffleMRepeatPerShuffle,
|
||||
index_t CShuffleNRepeatPerShuffle,
|
||||
typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
index_t NumGemmKPrefetchStage = 1,
|
||||
LoopScheduler LoopSched = make_default_loop_scheduler(),
|
||||
ck::PipelineVersion PipelineVer = ck::PipelineVersion::v1,
|
||||
typename ck::enable_if<NDimSpatial == 3, bool>::type = false>
|
||||
struct DeviceGroupedConvBwdWeight_Wmma_CShuffle
|
||||
: public DeviceGroupedConvBwdWeight<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation>
|
||||
{
|
||||
using DeviceOp = DeviceGroupedConvBwdWeight_Wmma_CShuffle;
|
||||
|
||||
using ADataType = OutDataType;
|
||||
using BDataType = InDataType;
|
||||
using CDataType = WeiDataType;
|
||||
|
||||
using AElementwiseOperation = OutElementwiseOperation;
|
||||
using BElementwiseOperation = InElementwiseOperation;
|
||||
using CElementwiseOperation = WeiElementwiseOperation;
|
||||
|
||||
// TODO make A/B datatype different
|
||||
using ABDataType = InDataType;
|
||||
|
||||
// 3d
|
||||
static constexpr bool is_NDHWGK_GKZYXC_NDHWGC =
|
||||
is_same_v<InLayout, tensor_layout::convolution::NDHWGC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::NDHWGK>;
|
||||
static constexpr bool is_GNDHWK_GKZYXC_GNDHWC =
|
||||
is_same_v<InLayout, tensor_layout::convolution::GNDHWC> &&
|
||||
is_same_v<WeiLayout, tensor_layout::convolution::GKZYXC> &&
|
||||
is_same_v<OutLayout, tensor_layout::convolution::GNDHWK>;
|
||||
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
static constexpr auto I1 = Number<1>{};
|
||||
static constexpr auto I2 = Number<2>{};
|
||||
static constexpr auto I3 = Number<3>{};
|
||||
static constexpr auto I4 = Number<4>{};
|
||||
static constexpr auto I5 = Number<5>{};
|
||||
|
||||
static constexpr auto GemmK1Number = Number<K1>{};
|
||||
static constexpr index_t KPerBlock = K0PerBlock * GemmK1Number;
|
||||
|
||||
template <index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
constexpr static auto
|
||||
make_out_grid_desc(const index_t N,
|
||||
const index_t Do,
|
||||
const index_t Ho,
|
||||
const index_t Wo,
|
||||
const index_t K,
|
||||
const std::array<index_t, NDimSpatial + 3>& output_strides)
|
||||
{
|
||||
const index_t WoStride = output_strides[5];
|
||||
const auto KStride = Number<1>{};
|
||||
return make_naive_tensor_descriptor(make_tuple(N * Do * Ho * Wo, K),
|
||||
make_tuple(WoStride, KStride));
|
||||
}
|
||||
|
||||
template <index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
constexpr static auto
|
||||
make_in_grid_desc(const index_t N,
|
||||
const index_t Di,
|
||||
const index_t Hi,
|
||||
const index_t Wi,
|
||||
const index_t C,
|
||||
const std::array<index_t, NDimSpatial + 3>& input_strides)
|
||||
{
|
||||
const index_t NStride = input_strides[1];
|
||||
const index_t DiStride = input_strides[3];
|
||||
const index_t HiStride = input_strides[4];
|
||||
const index_t WiStride = input_strides[5];
|
||||
const auto CStride = input_strides[2];
|
||||
if constexpr(ConvBackwardWeightSpecialization ==
|
||||
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
return make_naive_tensor_descriptor(make_tuple(N * Di * Hi * Wi, C),
|
||||
make_tuple(WiStride, CStride));
|
||||
}
|
||||
else
|
||||
{
|
||||
return make_naive_tensor_descriptor(
|
||||
make_tuple(N, Di, Hi, Wi, C),
|
||||
make_tuple(NStride, DiStride, HiStride, WiStride, CStride));
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
constexpr static auto
|
||||
make_wei_grid_desc(const index_t K,
|
||||
const index_t Z,
|
||||
const index_t Y,
|
||||
const index_t X,
|
||||
const index_t C,
|
||||
const std::array<index_t, NDimSpatial + 3>& weights_strides)
|
||||
{
|
||||
const auto CStride = Number<1>{};
|
||||
const auto KStride = weights_strides[1];
|
||||
return make_naive_tensor_descriptor(make_tuple(K, Z * Y * X * C),
|
||||
make_tuple(KStride, CStride));
|
||||
}
|
||||
|
||||
template <index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static auto MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(
|
||||
const index_t N,
|
||||
const index_t K,
|
||||
const index_t C,
|
||||
const std::array<index_t, NDimSpatial>& input_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial>& output_spatial_lengths,
|
||||
const std::array<index_t, NDimSpatial + 3>& input_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& weights_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& output_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads)
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
const index_t Di = input_spatial_lengths[0];
|
||||
const index_t Hi = input_spatial_lengths[1];
|
||||
const index_t Wi = input_spatial_lengths[2];
|
||||
|
||||
const index_t Do = output_spatial_lengths[0];
|
||||
const index_t Ho = output_spatial_lengths[1];
|
||||
const index_t Wo = output_spatial_lengths[2];
|
||||
|
||||
const index_t Z = filter_spatial_lengths[0];
|
||||
const index_t Y = filter_spatial_lengths[1];
|
||||
const index_t X = filter_spatial_lengths[2];
|
||||
|
||||
const index_t ConvStrideD = conv_filter_strides[0];
|
||||
const index_t ConvStrideH = conv_filter_strides[1];
|
||||
const index_t ConvStrideW = conv_filter_strides[2];
|
||||
|
||||
const index_t ConvDilationD = conv_filter_dilations[0];
|
||||
const index_t ConvDilationH = conv_filter_dilations[1];
|
||||
const index_t ConvDilationW = conv_filter_dilations[2];
|
||||
|
||||
const index_t InLeftPadD = input_left_pads[0];
|
||||
const index_t InLeftPadH = input_left_pads[1];
|
||||
const index_t InLeftPadW = input_left_pads[2];
|
||||
|
||||
const index_t InRightPadD = input_right_pads[0];
|
||||
const index_t InRightPadH = input_right_pads[1];
|
||||
const index_t InRightPadW = input_right_pads[2];
|
||||
|
||||
const index_t GemmKTotal = N * Do * Ho * Wo;
|
||||
const index_t GemmM = K;
|
||||
const index_t GemmN = C * Z * X * Y;
|
||||
|
||||
const auto PadGemmM = (MPerBlock - GemmM % MPerBlock) % MPerBlock;
|
||||
const auto PadGemmN = (NPerBlock - GemmN % NPerBlock) % NPerBlock;
|
||||
|
||||
const index_t GemmK0 =
|
||||
math::integer_divide_ceil(GemmKTotal, GemmK1Number * K0PerBlock) * K0PerBlock;
|
||||
const index_t GemmKPad = GemmK0 * GemmK1Number;
|
||||
|
||||
const auto out_grid_desc = make_out_grid_desc<NDim>(N, Do, Ho, Wo, K, output_strides);
|
||||
const auto in_grid_desc = make_in_grid_desc<NDim>(N, Di, Hi, Wi, C, input_strides);
|
||||
const auto wei_grid_desc = make_wei_grid_desc<NDim>(K, Z, Y, X, C, weights_strides);
|
||||
|
||||
if constexpr(ConvBackwardWeightSpecialization ==
|
||||
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// A: output tensor
|
||||
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// B: input tensor
|
||||
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
wei_grid_desc);
|
||||
}
|
||||
else
|
||||
{
|
||||
// A: output tensor
|
||||
const auto out_gemmkpad_gemmm_grid_desc = transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmkpad_gemmm_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmM)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// B: input tensor
|
||||
const auto in_n_dip_hip_wip_c_grid_desc = transform_tensor_descriptor(
|
||||
in_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N),
|
||||
make_pad_transform(Di, InLeftPadD, InRightPadD),
|
||||
make_pad_transform(Hi, InLeftPadH, InRightPadH),
|
||||
make_pad_transform(Wi, InLeftPadW, InRightPadW),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}));
|
||||
|
||||
const auto in_n_z_do_y_ho_x_wo_c_grid_desc = transform_tensor_descriptor(
|
||||
in_n_dip_hip_wip_c_grid_desc,
|
||||
make_tuple(
|
||||
make_pass_through_transform(N),
|
||||
make_embed_transform(make_tuple(Z, Do), make_tuple(ConvDilationD, ConvStrideD)),
|
||||
make_embed_transform(make_tuple(Y, Ho), make_tuple(ConvDilationH, ConvStrideH)),
|
||||
make_embed_transform(make_tuple(X, Wo), make_tuple(ConvDilationW, ConvStrideW)),
|
||||
make_pass_through_transform(C)),
|
||||
make_tuple(
|
||||
Sequence<0>{}, Sequence<1>{}, Sequence<2>{}, Sequence<3>{}, Sequence<4>{}),
|
||||
make_tuple(Sequence<0>{},
|
||||
Sequence<1, 2>{},
|
||||
Sequence<3, 4>{},
|
||||
Sequence<5, 6>{},
|
||||
Sequence<7>{}));
|
||||
|
||||
const auto in_gemmktotal_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_n_z_do_y_ho_x_wo_c_grid_desc,
|
||||
make_tuple(make_merge_transform(make_tuple(Z, Y, X, C)),
|
||||
make_merge_transform(make_tuple(N, Do, Ho, Wo))),
|
||||
make_tuple(Sequence<1, 3, 5, 7>{}, Sequence<0, 2, 4, 6>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0>{}));
|
||||
|
||||
const auto in_gemmkpad_gemmn_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmktotal_gemmn_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmKTotal, GemmKPad - GemmKTotal),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc = transform_tensor_descriptor(
|
||||
in_gemmkpad_gemmn_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(GemmK0, GemmK1Number)),
|
||||
make_pass_through_transform(GemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
|
||||
// Pad
|
||||
const auto out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
out_gemmkbatch_gemmk0_gemmm_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmK0),
|
||||
make_right_pad_transform(GemmM, PadGemmM),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc =
|
||||
transform_tensor_descriptor(
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_grid_desc,
|
||||
make_tuple(make_pass_through_transform(GemmK0),
|
||||
make_right_pad_transform(GemmN, PadGemmN),
|
||||
make_pass_through_transform(GemmK1Number)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
|
||||
|
||||
const auto wei_gemmm_gemmn_pad_grid_desc =
|
||||
transform_tensor_descriptor(wei_grid_desc,
|
||||
make_tuple(make_right_pad_transform(GemmM, PadGemmM),
|
||||
make_right_pad_transform(GemmN, PadGemmN)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}));
|
||||
|
||||
return make_tuple(out_gemmkbatch_gemmk0_gemmm_gemmk1_pad_grid_desc,
|
||||
in_gemmkbatch_gemmk0_gemmn_gemmk1_pad_grid_desc,
|
||||
wei_gemmm_gemmn_pad_grid_desc);
|
||||
}
|
||||
}
|
||||
|
||||
template <index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
|
||||
static auto GetABCGridDesc()
|
||||
{
|
||||
const index_t dim = 1;
|
||||
const std::array<index_t, NDimSpatial> lengths{1, 1, 1};
|
||||
const std::array<index_t, NDimSpatial + 3> strides{1, 1, 1, 1, 1, 1};
|
||||
const std::array<index_t, NDimSpatial> params{1, 1, 1};
|
||||
return MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<3>(dim,
|
||||
dim,
|
||||
dim,
|
||||
lengths,
|
||||
lengths,
|
||||
lengths,
|
||||
strides,
|
||||
strides,
|
||||
strides,
|
||||
params,
|
||||
params,
|
||||
params,
|
||||
params);
|
||||
}
|
||||
|
||||
using ABCGridDescs = decltype(GetABCGridDesc<NDimSpatial>());
|
||||
|
||||
using AGridDesc_K0_M_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I0])>;
|
||||
using BGridDesc_K0_N_K1 = remove_cvref_t<decltype(ABCGridDescs{}[I1])>;
|
||||
using CGridDesc_M_N = remove_cvref_t<decltype(ABCGridDescs{}[I2])>;
|
||||
|
||||
using GridwiseGemm = GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle<
|
||||
// DataType Family
|
||||
ADataType,
|
||||
BDataType,
|
||||
AccDataType,
|
||||
CDataType,
|
||||
Tuple<>,
|
||||
CDataType,
|
||||
// InMemory Data Descriptor
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
Tuple<>,
|
||||
CGridDesc_M_N,
|
||||
// ElementwiseOp Family
|
||||
AElementwiseOperation,
|
||||
BElementwiseOperation,
|
||||
CElementwiseOperation,
|
||||
InMemoryDataOperationEnum::Set,
|
||||
// Tiling Family
|
||||
MPerBlock,
|
||||
NPerBlock,
|
||||
K0PerBlock,
|
||||
MPerWMMA,
|
||||
NPerWMMA,
|
||||
K1,
|
||||
MRepeat,
|
||||
NRepeat,
|
||||
// ThreadCluster Family
|
||||
BlockSize,
|
||||
ABlockTransferThreadClusterLengths_K0_M_K1,
|
||||
ABlockTransferThreadClusterArrangeOrder,
|
||||
ABlockTransferSrcAccessOrder,
|
||||
ABlockTransferSrcVectorDim,
|
||||
ABlockTransferSrcScalarPerVector,
|
||||
ABlockTransferDstScalarPerVector_K1,
|
||||
false,
|
||||
ABlockLdsAddExtraM,
|
||||
BBlockTransferThreadClusterLengths_K0_N_K1,
|
||||
BBlockTransferThreadClusterArrangeOrder,
|
||||
BBlockTransferSrcAccessOrder,
|
||||
BBlockTransferSrcVectorDim,
|
||||
BBlockTransferSrcScalarPerVector,
|
||||
BBlockTransferDstScalarPerVector_K1,
|
||||
false,
|
||||
BBlockLdsAddExtraN,
|
||||
CShuffleMRepeatPerShuffle,
|
||||
CShuffleNRepeatPerShuffle,
|
||||
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CShuffleBlockTransferScalarPerVector_NPerBlock,
|
||||
NumGemmKPrefetchStage,
|
||||
LoopSched,
|
||||
PipelineVer>;
|
||||
|
||||
using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
decltype(GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(Tuple<>{}));
|
||||
|
||||
using CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock =
|
||||
decltype(GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
CGridDesc_M_N{}));
|
||||
|
||||
using Block2CTileMap = decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(
|
||||
CGridDesc_M_N{}, I1 /* M01 */, I1 /* N01 */));
|
||||
|
||||
struct Argument : public BaseArgument
|
||||
{
|
||||
Argument(const InDataType* p_in_grid,
|
||||
WeiDataType* p_wei_grid,
|
||||
const OutDataType* p_out_grid,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
index_t split_k)
|
||||
: p_a_grid_{p_out_grid},
|
||||
p_b_grid_{p_in_grid},
|
||||
p_c_grid_{p_wei_grid},
|
||||
a_grid_desc_kbatch_k0_m_k1_{},
|
||||
b_grid_desc_kbatch_k0_n_k1_{},
|
||||
c_grid_desc_m_n_{},
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_{},
|
||||
block_2_ctile_map_{},
|
||||
compute_ptr_offset_of_batch_{},
|
||||
a_element_op_{out_element_op},
|
||||
b_element_op_{in_element_op},
|
||||
c_element_op_{wei_element_op},
|
||||
Conv_G_{a_g_n_c_wis_lengths[0]},
|
||||
Conv_N_{a_g_n_c_wis_lengths[1]},
|
||||
Conv_K_{b_g_k_c_xs_lengths[1]},
|
||||
Conv_C_{a_g_n_c_wis_lengths[2]},
|
||||
input_spatial_lengths_{},
|
||||
filter_spatial_lengths_{},
|
||||
output_spatial_lengths_{},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads},
|
||||
k_batch_{split_k}
|
||||
{
|
||||
constexpr index_t spatial_offset = 3;
|
||||
std::copy(begin(a_g_n_c_wis_lengths) + spatial_offset,
|
||||
end(a_g_n_c_wis_lengths),
|
||||
begin(input_spatial_lengths_));
|
||||
std::copy(begin(b_g_k_c_xs_lengths) + spatial_offset,
|
||||
end(b_g_k_c_xs_lengths),
|
||||
begin(filter_spatial_lengths_));
|
||||
std::copy(begin(e_g_n_k_wos_lengths) + spatial_offset,
|
||||
end(e_g_n_k_wos_lengths),
|
||||
begin(output_spatial_lengths_));
|
||||
|
||||
const auto descs =
|
||||
DeviceOp::MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N<NDimSpatial>(
|
||||
Conv_N_,
|
||||
Conv_K_,
|
||||
Conv_C_,
|
||||
input_spatial_lengths_,
|
||||
filter_spatial_lengths_,
|
||||
output_spatial_lengths_,
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads);
|
||||
|
||||
a_grid_desc_kbatch_k0_m_k1_ = descs[I0];
|
||||
b_grid_desc_kbatch_k0_n_k1_ = descs[I1];
|
||||
c_grid_desc_m_n_ = descs[I2];
|
||||
|
||||
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(
|
||||
c_grid_desc_m_n_, I1 /* M01 */, I1 /* N01 */);
|
||||
|
||||
// A/B/C Batch Stride
|
||||
compute_ptr_offset_of_batch_.BatchStrideA_ = e_g_n_k_wos_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideB_ = a_g_n_c_wis_strides[0];
|
||||
compute_ptr_offset_of_batch_.BatchStrideE_ =
|
||||
Conv_K_ * Conv_C_ *
|
||||
std::accumulate(begin(filter_spatial_lengths_),
|
||||
end(filter_spatial_lengths_),
|
||||
index_t{1},
|
||||
std::multiplies<>{});
|
||||
|
||||
if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
|
||||
b_grid_desc_kbatch_k0_n_k1_,
|
||||
c_grid_desc_m_n_,
|
||||
block_2_ctile_map_))
|
||||
{
|
||||
c_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
c_grid_desc_m_n_);
|
||||
}
|
||||
}
|
||||
|
||||
const ADataType* p_a_grid_;
|
||||
const BDataType* p_b_grid_;
|
||||
CDataType* p_c_grid_;
|
||||
AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
|
||||
BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
|
||||
CGridDesc_M_N c_grid_desc_m_n_;
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock c_grid_desc_mblock_mperblock_nblock_nperblock_;
|
||||
|
||||
Block2CTileMap block_2_ctile_map_;
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
|
||||
|
||||
OutElementwiseOperation a_element_op_;
|
||||
InElementwiseOperation b_element_op_;
|
||||
WeiElementwiseOperation c_element_op_;
|
||||
|
||||
// for checking IsSupportedArgument()
|
||||
const index_t Conv_G_;
|
||||
const index_t Conv_N_;
|
||||
const index_t Conv_K_;
|
||||
const index_t Conv_C_;
|
||||
std::array<index_t, NDimSpatial> input_spatial_lengths_;
|
||||
std::array<index_t, NDimSpatial> filter_spatial_lengths_;
|
||||
std::array<index_t, NDimSpatial> output_spatial_lengths_;
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides_;
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads_;
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads_;
|
||||
const index_t k_batch_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
struct Invoker : public BaseInvoker
|
||||
{
|
||||
using Argument = DeviceOp::Argument;
|
||||
|
||||
void Print(const Argument& arg)
|
||||
{
|
||||
std::cout << "arg.a_grid_desc_kbatch_k0_m_k1_{"
|
||||
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", "
|
||||
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", "
|
||||
<< arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{"
|
||||
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", "
|
||||
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", "
|
||||
<< arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << "}" << std::endl;
|
||||
|
||||
std::cout << "arg.c_grid_desc_m_n_{" << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
|
||||
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
|
||||
}
|
||||
|
||||
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
if(stream_config.log_level_ > 0)
|
||||
{
|
||||
Print(arg);
|
||||
}
|
||||
|
||||
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"wrong! GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle has invalid "
|
||||
"setting");
|
||||
}
|
||||
|
||||
const index_t grid_size =
|
||||
arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Conv_G_;
|
||||
|
||||
const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1);
|
||||
|
||||
const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
|
||||
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop) {
|
||||
constexpr bool has_main_loop = has_main_k_block_loop.value;
|
||||
|
||||
const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle<
|
||||
GridwiseGemm,
|
||||
ADataType,
|
||||
BDataType,
|
||||
typename GridwiseGemm::DsGridPointer,
|
||||
CDataType,
|
||||
OutElementwiseOperation,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
AGridDesc_K0_M_K1,
|
||||
BGridDesc_K0_N_K1,
|
||||
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
|
||||
remove_reference_t<typename GridwiseGemm::DefaultBlock2CTileMap>,
|
||||
ComputePtrOffsetOfStridedBatch<I0>,
|
||||
has_main_loop>;
|
||||
|
||||
using EmptyTuple = Tuple<>;
|
||||
return launch_and_time_kernel(stream_config,
|
||||
kernel,
|
||||
dim3(grid_size),
|
||||
dim3(BlockSize),
|
||||
0,
|
||||
arg.p_a_grid_,
|
||||
arg.p_b_grid_,
|
||||
EmptyTuple{}, // Ds
|
||||
arg.p_c_grid_,
|
||||
arg.a_element_op_,
|
||||
arg.b_element_op_,
|
||||
arg.c_element_op_,
|
||||
arg.Conv_G_,
|
||||
arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock{},
|
||||
arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
|
||||
arg.block_2_ctile_map_,
|
||||
arg.compute_ptr_offset_of_batch_);
|
||||
};
|
||||
|
||||
if(has_main_k0_block_loop)
|
||||
{
|
||||
return launch_kernel(integral_constant<bool, true>{});
|
||||
}
|
||||
else
|
||||
{
|
||||
return launch_kernel(integral_constant<bool, false>{});
|
||||
}
|
||||
}
|
||||
|
||||
float Run(const BaseArgument* p_arg,
|
||||
const StreamConfig& stream_config = StreamConfig{}) override
|
||||
{
|
||||
return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
|
||||
}
|
||||
};
|
||||
|
||||
static constexpr bool IsValidCompilationParameter()
|
||||
{
|
||||
// TODO: properly implement this check
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
// check device
|
||||
if(get_device_name() == "gfx1100" || get_device_name() == "gfx1101" ||
|
||||
get_device_name() == "gfx1102")
|
||||
{
|
||||
if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// TODO: Add support for split_k > 1
|
||||
if(arg.k_batch_ != 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(!(is_NDHWGK_GKZYXC_NDHWGC || is_GNDHWK_GKZYXC_GNDHWC))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if constexpr(ConvBackwardWeightSpecialization ==
|
||||
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0)
|
||||
{
|
||||
// check if it's a 1x1 convolution with stride=1 and no padding
|
||||
for(int i = 0; i < NDimSpatial; i++)
|
||||
{
|
||||
if(!(arg.filter_spatial_lengths_[i] == 1 && arg.conv_filter_strides_[i] == 1 &&
|
||||
arg.input_left_pads_[i] == 0 && arg.input_right_pads_[i] == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// vector load A/B matrix from global memory
|
||||
if(!(ABlockTransferSrcVectorDim == 1 && BBlockTransferSrcVectorDim == 1 &&
|
||||
arg.Conv_K_ % ABlockTransferSrcScalarPerVector == 0 &&
|
||||
arg.Conv_C_ % BBlockTransferSrcScalarPerVector == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// vector store C matrix into global memory
|
||||
if(!(arg.Conv_C_ % CShuffleBlockTransferScalarPerVector_NPerBlock == 0))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Gridwise GEMM size
|
||||
return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_,
|
||||
arg.b_grid_desc_kbatch_k0_n_k1_,
|
||||
arg.c_grid_desc_m_n_,
|
||||
arg.block_2_ctile_map_);
|
||||
}
|
||||
|
||||
bool IsSupportedArgument(const BaseArgument* p_arg) override
|
||||
{
|
||||
return IsSupportedArgument(*dynamic_cast<const Argument*>(p_arg));
|
||||
}
|
||||
|
||||
static auto
|
||||
MakeArgument(const InDataType* p_in_grid,
|
||||
WeiDataType* p_wei_grid,
|
||||
const OutDataType* p_out_grid,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
const index_t split_k)
|
||||
{
|
||||
return Argument{p_in_grid,
|
||||
p_wei_grid,
|
||||
p_out_grid,
|
||||
a_g_n_c_wis_lengths, // input
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths, // weight
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_k_wos_lengths, // output
|
||||
e_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op,
|
||||
split_k};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
|
||||
std::unique_ptr<BaseArgument>
|
||||
MakeArgumentPointer(const void* p_in_grid,
|
||||
void* p_wei_grid,
|
||||
const void* p_out_grid,
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_lengths, // input
|
||||
const std::array<index_t, NDimSpatial + 3>& a_g_n_c_wis_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_lengths, // weight
|
||||
const std::array<index_t, NDimSpatial + 3>& b_g_k_c_xs_strides,
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_lengths, // output
|
||||
const std::array<index_t, NDimSpatial + 3>& e_g_n_k_wos_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_strides,
|
||||
const std::array<index_t, NDimSpatial>& conv_filter_dilations,
|
||||
const std::array<index_t, NDimSpatial>& input_left_pads,
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
InElementwiseOperation in_element_op,
|
||||
WeiElementwiseOperation wei_element_op,
|
||||
OutElementwiseOperation out_element_op,
|
||||
const index_t split_k) override
|
||||
{
|
||||
return std::make_unique<Argument>(static_cast<const InDataType*>(p_in_grid),
|
||||
static_cast<WeiDataType*>(p_wei_grid),
|
||||
static_cast<const OutDataType*>(p_out_grid),
|
||||
a_g_n_c_wis_lengths, // input
|
||||
a_g_n_c_wis_strides,
|
||||
b_g_k_c_xs_lengths, // weight
|
||||
b_g_k_c_xs_strides,
|
||||
e_g_n_k_wos_lengths, // output
|
||||
e_g_n_k_wos_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
in_element_op,
|
||||
wei_element_op,
|
||||
out_element_op,
|
||||
split_k);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
{
|
||||
return std::make_unique<Invoker>(Invoker{});
|
||||
}
|
||||
|
||||
std::string GetTypeString() const override
|
||||
{
|
||||
auto str = std::stringstream();
|
||||
|
||||
// clang-format off
|
||||
str << "DeviceGroupedConvBwdWeight_Wmma_CShuffle"
|
||||
<< "<"
|
||||
<< BlockSize << ", "
|
||||
<< MPerBlock << ", "
|
||||
<< NPerBlock << ", "
|
||||
<< K0PerBlock << ", "
|
||||
<< getConvBackwardWeightSpecializationString(ConvBackwardWeightSpecialization) << ", "
|
||||
<< K1 << ", "
|
||||
<< ABlockTransferSrcScalarPerVector << ", "
|
||||
<< ABlockTransferDstScalarPerVector_K1 << ", "
|
||||
<< BBlockTransferSrcScalarPerVector << ", "
|
||||
<< BBlockTransferDstScalarPerVector_K1 << ", "
|
||||
<< CShuffleMRepeatPerShuffle << ", "
|
||||
<< CShuffleNRepeatPerShuffle << ", "
|
||||
<< CShuffleBlockTransferScalarPerVector_NPerBlock
|
||||
<< ">";
|
||||
// clang-format on
|
||||
|
||||
return str.str();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/device_grouped_conv_bwd_weight.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_bwd_weight.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
#include "ck/host_utility/kernel_launch.hpp"
|
||||
|
||||
@@ -21,32 +22,6 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
namespace {
|
||||
|
||||
struct ComputePtrOffsetOfStridedBatch
|
||||
{
|
||||
__host__ __device__ constexpr long_index_t GetAPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideA_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetBPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideB_);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideC_);
|
||||
}
|
||||
|
||||
index_t BatchStrideA_;
|
||||
index_t BatchStrideB_;
|
||||
index_t BatchStrideC_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename GridwiseGemm,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
@@ -1222,7 +1197,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
Block2CTileMap block_2_ctile_map_;
|
||||
|
||||
// for computing batch offset
|
||||
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
|
||||
ComputePtrOffsetOfStridedBatch<I0> compute_ptr_offset_of_batch_;
|
||||
|
||||
index_t M01_;
|
||||
index_t N01_;
|
||||
@@ -1301,7 +1276,7 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
|
||||
remove_reference_t<DeviceOp::CGridDesc_MBlock_MPerBlock_NBlock_NPerBlock>,
|
||||
remove_reference_t<DeviceOp::Block2CTileMap>,
|
||||
ComputePtrOffsetOfStridedBatch,
|
||||
ComputePtrOffsetOfStridedBatch<I0>,
|
||||
has_main_loop>;
|
||||
|
||||
return launch_and_time_kernel(stream_config,
|
||||
@@ -1348,6 +1323,10 @@ struct DeviceGroupedConvBwdWeight_Xdl_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(!ck::is_xdl_supported())
|
||||
{
|
||||
return false;
|
||||
}
|
||||
if constexpr(NDimSpatial == 1)
|
||||
{
|
||||
if constexpr(!is_GNWK_GKXC_GNWC)
|
||||
|
||||
@@ -471,7 +471,7 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
auto launch_kernel = [&](auto has_main_k_block_loop) {
|
||||
constexpr bool has_main_loop = has_main_k_block_loop.value;
|
||||
|
||||
const auto kernel = kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle<
|
||||
const auto kernel = kernel_grouped_conv_multiple_d_wmma_cshuffle<
|
||||
GridwiseOp,
|
||||
ADataType,
|
||||
BDataType,
|
||||
|
||||
@@ -43,7 +43,13 @@ struct ComputePtrOffsetOfStridedBatch
|
||||
return ds_offset;
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
|
||||
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetEPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideE_);
|
||||
}
|
||||
|
||||
// alias for kernels without multiple D
|
||||
[[maybe_unused]] __host__ __device__ constexpr long_index_t GetCPtrOffset(index_t g_idx) const
|
||||
{
|
||||
return g_idx * static_cast<long_index_t>(BatchStrideE_);
|
||||
}
|
||||
@@ -52,6 +58,7 @@ struct ComputePtrOffsetOfStridedBatch
|
||||
index_t BatchStrideB_;
|
||||
Array<ck::index_t, NumDTensor> BatchStrideDs_;
|
||||
index_t BatchStrideE_;
|
||||
index_t& BatchStrideC_ = BatchStrideE_; // alias for kernels without multiple D
|
||||
};
|
||||
|
||||
} // namespace device
|
||||
|
||||
@@ -36,7 +36,7 @@ __global__ void
|
||||
#if CK_USE_LAUNCH_BOUNDS
|
||||
__launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
|
||||
#endif
|
||||
kernel_grouped_conv_fwd_multiple_d_wmma_cshuffle(
|
||||
kernel_grouped_conv_multiple_d_wmma_cshuffle(
|
||||
const ADataType* __restrict__ p_a_grid,
|
||||
const BDataType* __restrict__ p_b_grid,
|
||||
DsPointer p_ds_grid,
|
||||
@@ -452,11 +452,11 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
|
||||
}
|
||||
|
||||
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
|
||||
// CheckValidity for kernels without multi D
|
||||
template <typename Block2CTileMap>
|
||||
__host__ __device__ static constexpr bool
|
||||
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
|
||||
const DsGridDesc_M_N& ds_grid_desc_m_n,
|
||||
const EGridDesc_M_N& e_grid_desc_m_n,
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
{
|
||||
@@ -471,18 +471,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
|
||||
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
|
||||
const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
|
||||
|
||||
bool valid = true;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
|
||||
N == ds_grid_desc_m_n[i].GetLength(I1));
|
||||
});
|
||||
|
||||
if(!valid)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
|
||||
K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
|
||||
K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
|
||||
@@ -517,6 +505,31 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename Block2CTileMap>
|
||||
__host__ __device__ static constexpr bool
|
||||
CheckValidity(const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
|
||||
const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
|
||||
const DsGridDesc_M_N& ds_grid_desc_m_n,
|
||||
const EGridDesc_M_N& e_grid_desc_m_n,
|
||||
const Block2CTileMap& block_2_ctile_map)
|
||||
{
|
||||
const auto M = a_grid_desc_k0_m_k1.GetLength(I1);
|
||||
const auto N = b_grid_desc_k0_n_k1.GetLength(I1);
|
||||
bool valid = true;
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
valid = valid && (M == ds_grid_desc_m_n[i].GetLength(I0) &&
|
||||
N == ds_grid_desc_m_n[i].GetLength(I1));
|
||||
});
|
||||
|
||||
if(!valid)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
return CheckValidity(
|
||||
a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, e_grid_desc_m_n, block_2_ctile_map);
|
||||
}
|
||||
|
||||
__host__ __device__ static constexpr bool CalculateHasMainKBlockLoop(index_t K)
|
||||
{
|
||||
const index_t num_loop = K / (K0PerBlock * K1);
|
||||
|
||||
@@ -6,8 +6,6 @@
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_dl.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
@@ -0,0 +1,119 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvBwdWeightDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
template <index_t NDSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_weight_wmma_f16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#####################################| NumDim| A| B| C| AData| BData| CData| AccData| A| B| C| ConvForward| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//#####################################| Spatial| Layout| Layout| Layout| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeatPerWave| NRepeatPerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
//#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 64, 4, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
|
||||
// blocksize=256
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 8, 8, 16, 16, 8, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 2>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 256, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 16>, 4>,
|
||||
// blocksize=128
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 8, 8, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 8, 8, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 256, 8, 8, 16, 16, 1, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 256, 32, 8, 8, 16, 16, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
// blocksize=64
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 8, 8, 16, 16, 4, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 8, 8, 16, 16, 1, 4, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 8, 8, 16, 16, 2, 4, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 8, 8, 16, 16, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 8, 8, 16, 16, 1, 8, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
// blocksize=32
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 16, 32, 8, 8, 16, 16, 1, 2, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 16, 64, 8, 8, 16, 16, 1, 4, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 32, 64, 8, 8, 16, 16, 2, 4, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 32, 32, 8, 8, 16, 16, 2, 2, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 64, 32, 8, 8, 16, 16, 4, 2, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 64, 16, 8, 8, 16, 16, 4, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 32, 16, 8, 8, 16, 16, 2, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_weight_wmma_i8_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//#####################################| NumDim| A| B| C| AData| BData| CData| AccData| A| B| C| ConvForward| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//#####################################| Spatial| Layout| Layout| Layout| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeatPerWave| NRepeatPerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
//#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 64, 4, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
|
||||
// blocksize=256
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 256, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>,
|
||||
// blocksize=128
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 256, 8, 8, 16, 16, 4, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 256, 8, 8, 16, 16, 2, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 256, 8, 8, 16, 16, 1, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 8, 8, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 256, 32, 8, 8, 16, 16, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 256, 64, 8, 8, 16, 16, 8, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 2>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 256, 128, 8, 8, 16, 16, 8, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
// blocksize=64
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 8, 8, 16, 16, 1, 8, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 128, 8, 8, 16, 16, 2, 8, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 64, 8, 8, 16, 16, 8, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 8, 8, 16, 16, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
// blocksize=32
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 16, 64, 8, 8, 16, 16, 1, 4, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 64, 64, 8, 8, 16, 16, 4, 4, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 32, 32, 8, 8, 16, 16, 2, 2, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 64, 16, 8, 8, 16, 16, 4, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -163,6 +163,30 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances(
|
||||
@@ -177,6 +201,31 @@ void add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instances
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
@@ -202,6 +251,30 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
|
||||
@@ -231,6 +304,31 @@ void add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_
|
||||
BF8,
|
||||
F8>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef DL_KERNELS
|
||||
// dl
|
||||
@@ -694,6 +792,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
#endif
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
@@ -708,6 +810,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_f32_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
else if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
|
||||
@@ -737,6 +849,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
#endif
|
||||
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
@@ -752,6 +868,16 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> &&
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
set(GROUPED_CONV1D_BWD_WEIGHT
|
||||
device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp
|
||||
device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp
|
||||
device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp)
|
||||
xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f16_instance.cpp
|
||||
xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_f32_instance.cpp
|
||||
xdl/device_grouped_conv1d_bwd_weight_xdl_gnwc_gkxc_gnwk_bf16_instance.cpp)
|
||||
|
||||
if(DL_KERNELS)
|
||||
list(APPEND GROUPED_CONV1D_BWD_WEIGHT
|
||||
device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instance.cpp
|
||||
device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instance.cpp
|
||||
device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_instance.cpp
|
||||
device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instance.cpp
|
||||
device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instance.cpp
|
||||
device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_instance.cpp)
|
||||
dl/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f16_instance.cpp
|
||||
dl/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_f32_instance.cpp
|
||||
dl/device_grouped_conv1d_bwd_weight_dl_gnwc_gkxc_gnwk_bf16_instance.cpp
|
||||
dl/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f16_instance.cpp
|
||||
dl/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_f32_instance.cpp
|
||||
dl/device_grouped_conv1d_bwd_weight_dl_nwgc_gkxc_nwgk_bf16_instance.cpp)
|
||||
endif()
|
||||
|
||||
add_instance_library(device_grouped_conv1d_bwd_weight_instance ${GROUPED_CONV1D_BWD_WEIGHT})
|
||||
|
||||
@@ -1,19 +1,19 @@
|
||||
set(GROUPED_CONV2D_BWD_WEIGHT
|
||||
device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
|
||||
device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp
|
||||
device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp
|
||||
device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
|
||||
device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp)
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_f32_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp)
|
||||
|
||||
if(DL_KERNELS)
|
||||
list(APPEND GROUPED_CONV2D_BWD_WEIGHT
|
||||
device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
|
||||
device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp
|
||||
device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp
|
||||
device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
|
||||
device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp)
|
||||
dl/device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f16_instance.cpp
|
||||
dl/device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_f32_instance.cpp
|
||||
dl/device_grouped_conv2d_bwd_weight_dl_gnhwc_gkyxc_gnhwk_bf16_instance.cpp
|
||||
dl/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
dl/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
|
||||
dl/device_grouped_conv2d_bwd_weight_dl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp)
|
||||
endif()
|
||||
|
||||
add_instance_library(device_grouped_conv2d_bwd_weight_instance ${GROUPED_CONV2D_BWD_WEIGHT})
|
||||
|
||||
@@ -1,20 +1,30 @@
|
||||
set(GROUPED_CONV3D_BWD_WEIGHT
|
||||
device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
|
||||
device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
|
||||
device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp
|
||||
device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp)
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f16_comp_bf8_fp8_instance.cpp)
|
||||
|
||||
if(DL_KERNELS)
|
||||
list(APPEND GROUPED_CONV3D_BWD_WEIGHT
|
||||
device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
|
||||
device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
|
||||
device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp
|
||||
device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp)
|
||||
dl/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
|
||||
dl/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_f32_instance.cpp
|
||||
dl/device_grouped_conv3d_bwd_weight_dl_gndhwc_gkzyxc_gndhwk_bf16_instance.cpp
|
||||
dl/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
dl/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
dl/device_grouped_conv3d_bwd_weight_dl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp)
|
||||
endif()
|
||||
|
||||
list(APPEND GROUPED_CONV3D_BWD_WEIGHT
|
||||
wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instance.cpp)
|
||||
|
||||
add_instance_library(device_grouped_conv3d_bwd_weight_instance ${GROUPED_CONV3D_BWD_WEIGHT})
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_wmma_f16_instances<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,35 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_wmma_f16_instances<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
ConvBwdWeightDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,35 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_wmma_i8_instances<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,35 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_wmma_i8_instances<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
ConvBwdWeightDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,35 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_wmma_f16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,35 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_wmma_f16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,35 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_wmma_i8_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,35 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_weight/device_grouped_conv_bwd_weight_wmma_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_weight_wmma_i8_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
ConvBwdWeightDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -147,7 +147,9 @@ GB/s: 127.947
|
||||
# arg1: tensor operation (grouped_conv_bwd_weight: Grouped Convolution Backward Weight)
|
||||
# arg2: data type (0: Input fp32, Weight fp32, Output fp32
|
||||
# 1: Input fp16, Weight fp16, Output fp16
|
||||
# 2: Input bf16, Weight fp32, Output bf16)
|
||||
# 2: Input bf16, Weight fp32, Output bf16
|
||||
# 3: Input fp16, Weight fp16, Output fp16, Gemm bf8@fp8
|
||||
# 4: Input int8, Weight int8, Output int8)
|
||||
# arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, N, K, Ho, Wo]
|
||||
# 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]
|
||||
# 2: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]
|
||||
@@ -167,7 +169,7 @@ GB/s: 127.947
|
||||
# SplitK
|
||||
|
||||
################ op datatype layout verify init log time Ndims G N K C Y X Hi Wi Sy Sx Dy Dx LeftPy LeftPx RightPy RightPx SplitK
|
||||
./bin/ckProfiler grouped_conv_bwd_weight 1 0 1 1 0 1 2 32 256 256 512 3 3 28 28 1 1 1 1 1 0 0 0 1
|
||||
./bin/ckProfiler grouped_conv_bwd_weight 1 1 0 1 0 1 2 32 256 256 512 3 3 28 28 1 1 1 1 1 0 0 0 1
|
||||
|
||||
```
|
||||
|
||||
|
||||
@@ -20,10 +20,11 @@ enum struct ConvLayout
|
||||
|
||||
enum struct ConvDataType
|
||||
{
|
||||
F32_F32_F32, // 0
|
||||
F16_F16_F16, // 1
|
||||
BF16_F32_BF16, // 2
|
||||
F16_F16_F16_BF8_F8 // 3
|
||||
F32_F32_F32, // 0
|
||||
F16_F16_F16, // 1
|
||||
BF16_F32_BF16, // 2
|
||||
F16_F16_F16_BF8_F8, // 3
|
||||
I8_I8_I8 // 4
|
||||
};
|
||||
|
||||
#define OP_NAME "grouped_conv_bwd_weight"
|
||||
@@ -35,7 +36,8 @@ static void print_helper_msg()
|
||||
<< "arg2: data type (0: Input fp32, Weight fp32, Output fp32\n"
|
||||
<< " 1: Input fp16, Weight fp16, Output fp16\n"
|
||||
<< " 2: Input bf16, Weight fp32, Output bf16\n"
|
||||
<< " 3: Input fp16, Weight fp16, Output fp16, Gemm bf8@fp8)\n"
|
||||
<< " 3: Input fp16, Weight fp16, Output fp16, Gemm bf8@fp8\n"
|
||||
<< " 4: Input int8, Weight int8, Output int8)\n"
|
||||
<< "arg3: tensor layout (0: Input[G, N, C, Hi, Wi], Weight[G, K, C, Y, X], Output[G, "
|
||||
"N, K, Ho, Wo]\n"
|
||||
<< " 1: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, "
|
||||
@@ -196,6 +198,11 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
// fp32 atomic add is used for weight tensor in bf16 kernel
|
||||
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, BF16{}, F32{}, BF16{}, BF16{}, BF16{});
|
||||
}
|
||||
else if(data_type == ConvDataType::I8_I8_I8)
|
||||
{
|
||||
return profile(
|
||||
I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{});
|
||||
}
|
||||
}
|
||||
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
|
||||
{
|
||||
@@ -216,6 +223,11 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
|
||||
{
|
||||
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F16{}, F16{}, F16{}, BF8{}, F8{});
|
||||
}
|
||||
else if(data_type == ConvDataType::I8_I8_I8)
|
||||
{
|
||||
return profile(
|
||||
I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, int8_t{}, int8_t{}, int8_t{}, int8_t{}, int8_t{});
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "this data_type & layout is not implemented" << std::endl;
|
||||
|
||||
@@ -1,11 +1,20 @@
|
||||
list(APPEND gpu_list gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
list(APPEND gpu_list_xdl gfx908 gfx90a gfx940 gfx941 gfx942)
|
||||
list(APPEND gpu_list_wmma gfx1100 gfx1101 gfx1102)
|
||||
|
||||
set(target 0)
|
||||
foreach(gpu IN LISTS GPU_TARGETS)
|
||||
if(gpu IN_LIST gpu_list AND target EQUAL 0)
|
||||
if(gpu IN_LIST gpu_list_xdl AND target EQUAL 0)
|
||||
add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp)
|
||||
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance)
|
||||
add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface.cpp)
|
||||
target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance)
|
||||
add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_xdl.cpp)
|
||||
target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility)
|
||||
set(target 1)
|
||||
endif()
|
||||
if(gpu IN_LIST gpu_list_wmma AND target EQUAL 0)
|
||||
add_gtest_executable(test_grouped_convnd_bwd_weight test_grouped_convnd_bwd_weight.cpp)
|
||||
target_link_libraries(test_grouped_convnd_bwd_weight PRIVATE utility device_grouped_conv1d_bwd_weight_instance device_grouped_conv2d_bwd_weight_instance device_grouped_conv3d_bwd_weight_instance)
|
||||
add_gtest_executable(test_grouped_convnd_bwd_weight_interface test_grouped_convnd_bwd_weight_interface_wmma.cpp)
|
||||
target_link_libraries(test_grouped_convnd_bwd_weight_interface PRIVATE utility)
|
||||
set(target 1)
|
||||
endif()
|
||||
endforeach()
|
||||
@@ -11,6 +11,7 @@
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/host_utility/device_prop.hpp"
|
||||
|
||||
#include "profiler/profile_grouped_conv_bwd_weight_impl.hpp"
|
||||
|
||||
@@ -33,8 +34,9 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
|
||||
|
||||
bool skip_case(const ck::utils::conv::ConvParam& params, const ck::index_t split_k)
|
||||
{
|
||||
// Odd K or C values are supported only by DL kernel (only applies to fp16)
|
||||
// DL kernel currently supports only `split_k=1`
|
||||
// Odd K or C values are supported only by DL and WMMA
|
||||
// kernels (only applies to fp16)
|
||||
// DL and WMMA kernels currently support only `split_k=1`
|
||||
if constexpr(std::is_same_v<InDataType, ck::half_t>)
|
||||
{
|
||||
if(split_k != 1 && (params.K_ % 2 != 0 || params.C_ % 2 != 0))
|
||||
@@ -53,6 +55,42 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
|
||||
}
|
||||
}
|
||||
|
||||
const bool is_navi3x = ck::get_device_name() == "gfx1100" ||
|
||||
ck::get_device_name() == "gfx1101" ||
|
||||
ck::get_device_name() == "gfx1102";
|
||||
if(is_navi3x)
|
||||
{
|
||||
// on navi3x only support for 3d is implemented
|
||||
if constexpr(NDimSpatial{} != 3)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
// on navi3x only support for i8 and fp16 is implemented
|
||||
if constexpr(!((std::is_same_v<InDataType, int8_t> &&
|
||||
std::is_same_v<WeiDataType, int8_t> &&
|
||||
std::is_same_v<OutDataType, int8_t>) ||
|
||||
(std::is_same_v<InDataType, ck::half_t> &&
|
||||
std::is_same_v<WeiDataType, ck::half_t> &&
|
||||
std::is_same_v<OutDataType, ck::half_t>)))
|
||||
{
|
||||
return true;
|
||||
}
|
||||
// WMMA kernel is only supported for split_k=1
|
||||
if(split_k != 1)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
// support for i8 is only implemented on navi3x
|
||||
if constexpr(std::is_same_v<InDataType, int8_t> &&
|
||||
std::is_same_v<WeiDataType, int8_t> && std::is_same_v<OutDataType, int8_t>)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
@@ -120,9 +158,11 @@ using KernelTypes3d = ::testing::Types<
|
||||
std::tuple<float, float, float, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
|
||||
std::tuple<ck::half_t, ck::half_t, ck::half_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
|
||||
std::tuple<ck::bhalf_t, float, ck::bhalf_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
|
||||
std::tuple<int8_t, int8_t, int8_t, GNDHWC, GKZYXC, GNDHWK, ck::Number<3>>,
|
||||
std::tuple<float, float, float, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
|
||||
std::tuple<ck::half_t, ck::half_t, ck::half_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
|
||||
std::tuple<ck::bhalf_t, float, ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>>;
|
||||
std::tuple<ck::bhalf_t, float, ck::bhalf_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>,
|
||||
std::tuple<int8_t, int8_t, int8_t, NDHWGC, GKZYXC, NDHWGK, ck::Number<3>>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdWeight1d, KernelTypes1d);
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdWeight2d, KernelTypes2d);
|
||||
|
||||
@@ -0,0 +1,191 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <initializer_list>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
|
||||
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/utility/algorithm.hpp"
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
using ConvolutionBackwardWeightSpecialization =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization;
|
||||
|
||||
static constexpr auto ConvBwdWeightDefault = ConvolutionBackwardWeightSpecialization::Default;
|
||||
static constexpr auto Filter1x1Stride1Pad0 =
|
||||
ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
template <typename Tuple, ConvolutionBackwardWeightSpecialization ConvSpec>
|
||||
class TestGroupedConvndBwdWeight : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using OutLayout = std::tuple_element_t<0, Tuple>;
|
||||
using WeiLayout = std::tuple_element_t<1, Tuple>;
|
||||
using InLayout = std::tuple_element_t<2, Tuple>;
|
||||
static constexpr ck::index_t NDimSpatial = std::tuple_element_t<3, Tuple>{};
|
||||
|
||||
// clang-format off
|
||||
using GroupedConvBwdWeightDeviceInstance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle
|
||||
//| NumDim| A| B| C| AData| BData| CData| AccData| A| B| C| ConvForward| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//| Spatial| Layout| Layout| Layout| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeatPerWave| NRepeatPerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
//| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
//| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
<NDimSpatial, InLayout, WeiLayout, OutLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 8, 8, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>;
|
||||
// clang-format on
|
||||
|
||||
ck::utils::conv::ConvParam conv_param;
|
||||
|
||||
template <ck::index_t SplitK>
|
||||
bool Run()
|
||||
{
|
||||
|
||||
const auto in_g_n_c_wis_desc =
|
||||
ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<InLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto wei_g_k_c_xs_desc =
|
||||
ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<WeiLayout>(
|
||||
conv_param);
|
||||
|
||||
const auto out_g_n_k_wos_desc =
|
||||
ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<OutLayout>(
|
||||
conv_param);
|
||||
|
||||
std::array<ck::index_t, NDimSpatial + 3> input_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> filter_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> output_lengths{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> input_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> weights_strides{};
|
||||
std::array<ck::index_t, NDimSpatial + 3> output_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_strides{};
|
||||
std::array<ck::index_t, NDimSpatial> conv_filter_dilations{};
|
||||
std::array<ck::index_t, NDimSpatial> input_left_pads{};
|
||||
std::array<ck::index_t, NDimSpatial> input_right_pads{};
|
||||
|
||||
auto range_copy = [](const auto& from, auto to) { std::copy(begin(from), end(from), to); };
|
||||
|
||||
range_copy(in_g_n_c_wis_desc.GetLengths(), begin(input_lengths));
|
||||
range_copy(in_g_n_c_wis_desc.GetStrides(), begin(input_strides));
|
||||
range_copy(wei_g_k_c_xs_desc.GetLengths(), begin(filter_lengths));
|
||||
range_copy(wei_g_k_c_xs_desc.GetStrides(), begin(weights_strides));
|
||||
range_copy(out_g_n_k_wos_desc.GetLengths(), begin(output_lengths));
|
||||
range_copy(out_g_n_k_wos_desc.GetStrides(), begin(output_strides));
|
||||
range_copy(conv_param.conv_filter_strides_, begin(conv_filter_strides));
|
||||
range_copy(conv_param.conv_filter_dilations_, begin(conv_filter_dilations));
|
||||
range_copy(conv_param.input_left_pads_, begin(input_left_pads));
|
||||
range_copy(conv_param.input_right_pads_, begin(input_right_pads));
|
||||
|
||||
auto conv = GroupedConvBwdWeightDeviceInstance{};
|
||||
|
||||
auto argument = conv.MakeArgument(nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
input_lengths,
|
||||
input_strides,
|
||||
filter_lengths,
|
||||
weights_strides,
|
||||
output_lengths,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
PassThrough{},
|
||||
SplitK);
|
||||
return conv.IsSupportedArgument(argument);
|
||||
}
|
||||
};
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using KernelTypes3d = ::testing::Types<std::tuple<GNDHWK, GKZYXC, GNDHWC, ck::Number<3>>,
|
||||
std::tuple<NDHWGK, GKZYXC, NDHWGC, ck::Number<3>>>;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdWeightFilter1x13d
|
||||
: public TestGroupedConvndBwdWeight<Tuple, Filter1x1Stride1Pad0>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdWeightDefault3d
|
||||
: public TestGroupedConvndBwdWeight<Tuple, ConvBwdWeightDefault>
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdWeightFilter1x13d, KernelTypes3d);
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdWeightDefault3d, KernelTypes3d);
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdWeightFilter1x13d, SpecializationCheck)
|
||||
{
|
||||
// Check filter 3x3x3 instead of 1x1x1
|
||||
this->conv_param = {
|
||||
3, 2, 4, 192, 192, {3, 3, 3}, {28, 28, 28}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
|
||||
bool is_supported = this->template Run<1>();
|
||||
EXPECT_FALSE(is_supported);
|
||||
|
||||
// Check strides 2x2x2 instead of 1x1x1
|
||||
this->conv_param = {
|
||||
3, 2, 4, 192, 192, {1, 1, 1}, {28, 28, 28}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
|
||||
is_supported = this->template Run<1>();
|
||||
EXPECT_FALSE(is_supported);
|
||||
|
||||
// Check with pad
|
||||
this->conv_param = {
|
||||
3, 2, 4, 192, 192, {1, 1, 1}, {28, 28, 28}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}};
|
||||
is_supported = this->template Run<1>();
|
||||
EXPECT_FALSE(is_supported);
|
||||
|
||||
// Supported version
|
||||
this->conv_param = {
|
||||
3, 2, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
|
||||
is_supported = this->template Run<1>();
|
||||
EXPECT_TRUE(is_supported);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdWeightDefault3d, VectorLoadCheck)
|
||||
{
|
||||
// vector load for A
|
||||
this->conv_param = {
|
||||
3, 2, 128, 129, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
|
||||
bool is_supported = this->template Run<1>();
|
||||
EXPECT_FALSE(is_supported);
|
||||
// vector load for B, E, Ds
|
||||
this->conv_param = {
|
||||
3, 2, 128, 128, 257, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
|
||||
is_supported = this->template Run<1>();
|
||||
EXPECT_FALSE(is_supported);
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdWeightDefault3d, SplitKCheck)
|
||||
{
|
||||
// SplitK=1
|
||||
this->conv_param = {
|
||||
3, 2, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
|
||||
bool is_supported = this->template Run<1>();
|
||||
EXPECT_TRUE(is_supported);
|
||||
// SplitK=2
|
||||
this->conv_param = {
|
||||
3, 2, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
|
||||
is_supported = this->template Run<2>();
|
||||
EXPECT_FALSE(is_supported);
|
||||
}
|
||||
Reference in New Issue
Block a user