mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 03:49:41 +00:00
Wmma support for grouped convolution bwd weight (#2947)
* Convolution bwd weight device implementation
* Merge branch 'grouped_conv_bwd_weight_device_impl_wmma' into 'feature/conv_bwd_weight_wmma'
Convolution bwd weight device implementation
See merge request amd/ai/composable_kernel!38
* Fix bug and disable splitK=-1 tests for wmma
* Add generic instances for bf16 f32 bf16
* check gridwise level validity in device impl for 1 stage D0
* Fix bugs in device implementation:
- rdna3 compilation error
- gridwise layouts (need to be correct to ensure that CheckValidaity()
works correctly)
* Add padding in conv to gemm transformers for 1x1Stride1Pad0 specialization
* Remove workaround for 1x1Stride1Pad0 conv specialization
* Add instances for xdl parity (for pipeline v1)
* Add two stage instances (xdl parity)
* Add multiple Ds instances
* Add examples
* Uncomment scale instances
* Fix copyright
* Fix examples compilation
* Add atomic add float4
* Fix compilation error
* Fix instances
* Compute tolerances in examples instead of using default ones
* Compute tolerances instead of using default ones in bilinear and scale tests
* Merge branch 'grouped_conv_bwd_weight_instances_examples' into 'feature/conv_bwd_weight_wmma'
Grouped conv: Instances and example bwd weight
See merge request amd/ai/composable_kernel!47
* Device implementation of explicit gemm for grouped conv bwd weight
Based on batched gemm multiple D
* Add instances for pipeline v1 and v3
* Add support for occupancy-based splitk
* Fix ckProfiler dependencies
* Review fixes
* Merge branch 'explicit_bwd_weight' into 'feature/conv_bwd_weight_wmma'
Device implementation of explicit gemm for grouped conv bwd weight
See merge request amd/ai/composable_kernel!52
* Fix cmake file for tests
* fix clang format
* fix instance factory error
* Adapt all grouped conv bwd weight vanilla Xdl instances to 16x16. MRepeat doubled for all but 12 of them (some static assert failure). Also added custom reduced profiler target for building grouped conv bwd weight vanilla only profiler. Verified with gtest test.
* Revert "Adapt all grouped conv bwd weight vanilla Xdl instances to 16x16. MRepeat doubled for all but 12 of them (some static assert failure). Also added custom reduced profiler target for building grouped conv bwd weight vanilla only profiler. Verified with gtest test."
This reverts commit da8e4cfb7917d45d46339ec74eb72e2f585f14cf.
* Disable splitk for 2stage xdl on rdna (bug to be fixed)
* Fix add_test_executable
* Always ForceThreadTileTransfer for now, WaveTileTransfer does not work for convolution yet.
* Grab device and gridwise files from bkp branch, this should enable splitK support for convolution and also we no longer ForceThreadTileTransfer for explicit gemm. Also grab some updates from 7e7243783008b11e904f127ecf1df55ef95e9af2 to fix building on clang20.
* Fix bug in various bwd wei device implementations / profiler where the occupancy based split_k value could not be found because the Argument did not derive from ArgumentSplitK, leading to incorrect error tolerances.
* Actually print the reason when a device implementation is not supported.
* Print number of valid instances in profiler and tests.
* Fix clang format for Two Stage implementation
* Fix copyright
* Address review comments
* Fix explicit conv bwd weight struct
* Fix gridwise common
* Fix gridwise ab scale
* Remove autodeduce 1 stage
* Restore example tolerance calculation
* Fix compilation error
* Fix gridwise common
* Fix gridwise gemm
* Fix typo
* Fix splitk
* Fix splitk ab scale
* Adapt all grouped conv bwd weight vanilla Xdl instances to 16x16. MRepeat doubled for all but 12 of them (some static assert failure). Also added custom reduced profiler target for building grouped conv bwd weight vanilla only profiler. Verified with gtest test.
* Reduce instances to only the tuned wmma V3 ones for implicit v1 intra and explicit v1 intra pad/nopad.
* Add explicit oddMN support with custom tuned instances
* Add two stage instances based on the parameters from the tuned cshuffle V3 instances. CShuffleBlockTranserScalarPerVector adapted to 4, and mergegroups fixed to 1 for now. No more special instance lists.
* Replace cshuffle non-v3 lists with v3 lists, making sure to not have duplications. Also removing stride1pad0 support for NHWGC since we can use explicit for those cases.
* Remove some instances that give incorrect results (f16 NHWGC)
* Add bf16 f32 bf16 instances based on tuned b16 NHWGC GKYXC instances.
* Add back some generic instances to make sure we have the same shape / layout / datatype support as before the instance selection process.
* Add instances for scale and bilinear based on the bf16 NHWGC GKYXC tuning. Keep generic instances for support.
* Disable two stage f16 instances which produce incorrect results.
* Remove more instances which fail verification, for bf16_f32_bf16 and for f16 scale / bilinear.
* Disable all non-generic two-stage instances in the instance lists for NHWGC. They are never faster and support is already carried by CShuffleV3 and Explicit.
* Remove unused instance lists and related add_x_instance() functions, fwd declarations, cmakelists entries. Also merge the "wmma" and "wmma v3" instance list files, which are both v3.
* Re-enable all xdl instances (un-16x16-adapted) and dl instances. Remove custom ckProfiler target.
* Remove straggler comments
* Remove [[maybe_unused]]
* Fix clang format
* Remove unwanted instances. This includes all instances which are not NHWGCxGKYXC and F16 or BF16 (no mixed in-out types).
* Add comment
---------
Co-authored-by: kiefer <kiefer.van.teutem@streamhpc.com>
Co-authored-by: Kiefer van Teutem <50830967+krithalith@users.noreply.github.com>
[ROCm/composable_kernel commit: 87dd073887]
This commit is contained in:
@@ -7,7 +7,7 @@
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck/utility/functional2.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit_xdl.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_explicit.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -32,17 +32,17 @@ void add_explicit_gemm_device_operation_instances(
|
||||
ck::static_for<0, std::tuple_size_v<DeviceGemmV3Ops>, 1>{}([&](auto i) {
|
||||
using DeviceGemmOp = std::tuple_element_t<i, DeviceGemmV3Ops>;
|
||||
|
||||
using NewOpInstance = DeviceGroupedConvBwdWeight_Explicit_Xdl<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
DeviceGemmOp>;
|
||||
using NewOpInstance = DeviceGroupedConvBwdWeight_Explicit<NDimSpatial,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType,
|
||||
InElementwiseOperation,
|
||||
WeiElementwiseOperation,
|
||||
OutElementwiseOperation,
|
||||
DeviceGemmOp>;
|
||||
|
||||
static_assert(std::is_base_of_v<BaseOp, NewOpInstance>,
|
||||
"wrong! NewOpInstance should be derived from BaseOp");
|
||||
|
||||
@@ -0,0 +1,138 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_grouped_conv_bwd_wei_exp_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using BF16 = bhalf_t;
|
||||
using F16 = half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Row = tensor_layout::gemm::RowMajor;
|
||||
using Col = tensor_layout::gemm::ColumnMajor;
|
||||
|
||||
template <index_t... Is>
|
||||
using S = Sequence<Is...>;
|
||||
|
||||
using PassThrough = element_wise::PassThrough;
|
||||
|
||||
static constexpr auto GemmDefault = GemmSpecialization::Default;
|
||||
static constexpr auto GemmKPadding = GemmSpecialization::KPadding;
|
||||
static constexpr auto GemmMPadding = GemmSpecialization::MPadding;
|
||||
static constexpr auto GemmMNPadding = GemmSpecialization::MNPadding;
|
||||
static constexpr auto GemmMKPadding = GemmSpecialization::MKPadding;
|
||||
static constexpr auto GemmMNKPadding = GemmSpecialization::MNKPadding;
|
||||
|
||||
static constexpr auto Intrawave = BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
|
||||
|
||||
template <typename InOutDataType>
|
||||
using device_gemm_wmma_universal_km_kn_mn_GemmDefault_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#####################################| ALayout| BLayout| DsLayout |ELayout| ADataType| BDataType| DsDataType| CDataType| AccDataType| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| BlockwiseGemm| BlockwiseGemm|
|
||||
//#####################################| | | | | | | | | | DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline|
|
||||
//#####################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | Scheduler| Verision|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 48, 96, 64, 8, 8, 16, 16, 3, 3, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 32, 64, 128, 8, 8, 16, 16, 2, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 192, 48, 96, 192, 8, 8, 16, 16, 3, 1, S<24, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<24, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 12>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 48, 64, 64, 8, 8, 16, 16, 3, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 128, 96, 128, 64, 8, 8, 16, 16, 6, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 192, 32, 96, 192, 8, 8, 16, 16, 2, 1, S<24, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<24, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, 1, 1, S<1, 16, 1, 12>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 128, 8, 8, 16, 16, 2, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 192, 32, 96, 192, 8, 8, 16, 16, 2, 1, S<24, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<24, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 12>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 64, 64, 8, 8, 16, 16, 2, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 64, 64, 32, 8, 8, 16, 16, 4, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, InOutDataType, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmDefault, 64, 32, 32, 64, 8, 8, 16, 16, 2, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_gemm_wmma_universal_km_kn_mn_GemmMNKPadding_f16_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#####################################| ALayout| BLayout| DsLayout |ELayout| ADataType| BDataType| DsDataType| CDataType| AccDataType| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| BlockwiseGemm| BlockwiseGemm|
|
||||
//#####################################| | | | | | | | | | DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline|
|
||||
//#####################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | Scheduler| Verision|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 128, 48, 64, 128, 8, 8, 16, 16, 3, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 48, 64, 64, 8, 8, 16, 16, 3, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 96, 64, 32, 8, 8, 16, 16, 6, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 128, 8, 8, 16, 16, 2, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 128, 8, 8, 16, 16, 2, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 48, 32, 128, 8, 8, 16, 16, 3, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
// DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 96, 64, 96, 48, 8, 8, 16, 16, 4, 2, S<6, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<6, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, // Incorrect results for f16
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, F16, F16, Tuple<>, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 128, 8, 8, 16, 16, 2, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_gemm_wmma_universal_km_kn_mn_GemmMNKPadding_bf16_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#####################################| ALayout| BLayout| DsLayout |ELayout| ADataType| BDataType| DsDataType| CDataType| AccDataType| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| BlockwiseGemm| BlockwiseGemm|
|
||||
//#####################################| | | | | | | | | | DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline|
|
||||
//#####################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | Scheduler| Verision|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 128, 48, 64, 128, 8, 8, 16, 16, 3, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<16, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 64, 64, 8, 8, 16, 16, 4, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 64, 96, 32, 8, 8, 16, 16, 4, 3, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 48, 64, 64, 8, 8, 16, 16, 3, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 96, 64, 32, 8, 8, 16, 16, 6, 2, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 128, 8, 8, 16, 16, 2, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 128, 8, 8, 16, 16, 2, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 48, 32, 128, 8, 8, 16, 16, 3, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 96, 64, 8, 8, 16, 16, 2, 3, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 96, 64, 96, 48, 8, 8, 16, 16, 4, 2, S<6, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<6, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, // Incorrect results for f16
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, BF16, BF16, Tuple<>, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 64, 32, 32, 128, 8, 8, 16, 16, 2, 1, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, S<16, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 4>, S<8, 8, 8>, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <typename InOutDataType, BlockGemmPipelineScheduler BlkGemmPipeSched>
|
||||
using device_gemm_wmma_universal_km_kn_mn_irregular_odd_mn_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#####################################| ALayout| BLayout| DsLayout |ELayout| ADataType| BDataType| DsDataType| CDataType| AccDataType| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransferClusterLengths| CShuffleBlockTransfer| BlockwiseGemm| BlockwiseGemm|
|
||||
//#####################################| | | | | | | | | | DataType| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| _MBlock_MPerBlock_NBlock_NPerBlock| ScalarPerVectors| Pipeline| Pipeline|
|
||||
//#####################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| | | Scheduler| Verision|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// Latency friendly
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 32, 64, 8, 8, 16, 16, 1, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 32, 64, 8, 8, 16, 16, 1, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 32, 128, 8, 8, 16, 16, 1, 2, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 48, 128, 8, 8, 16, 16, 1, 3, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 64, 64, 8, 8, 16, 16, 1, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 64, 128, 8, 8, 16, 16, 1, 4, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 96, 64, 8, 8, 16, 16, 1, 6, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 96, 128, 8, 8, 16, 16, 1, 6, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 128, 192, 32, 8, 8, 16, 16, 1, 12, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>,
|
||||
DeviceBatchedGemmMultiD_Wmma_CShuffleV3< Col, Row, Tuple<>, Row, InOutDataType, InOutDataType, Tuple<>, F32, F32, InOutDataType, PassThrough, PassThrough, PassThrough, GemmMNKPadding, 256, 256, 96, 64, 8, 8, 16, 16, 2, 6, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 16>, S<1, 1, 1>, BlkGemmPipeSched, BlockGemmPipelineVersion::v1>
|
||||
// Memory friendly
|
||||
// TODO: add once v2 is implemented
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,91 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_two_stage_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvBwdWeightDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec,
|
||||
BlockGemmPipelineScheduler Scheduler,
|
||||
BlockGemmPipelineVersion PipelineVersion>
|
||||
using device_grouped_conv_bwd_weight_two_stage_nhwgc_wmma_c_shuffle_f16_instances = std::tuple<
|
||||
// clang-format off
|
||||
//################################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| NumGroups|
|
||||
//################################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
|
||||
//################################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Sched| Ver| |
|
||||
//################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
|
||||
DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 0, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 0, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 1>
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1>
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 4, Scheduler, PipelineVersion, 1>, // Incorrect results for at least GemmDefault
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 4, Scheduler, PipelineVersion, 1> // Incorrect results for at least GemmDefault
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec,
|
||||
BlockGemmPipelineScheduler Scheduler,
|
||||
BlockGemmPipelineVersion PipelineVersion>
|
||||
using device_grouped_conv_bwd_weight_two_stage_nhwgc_wmma_c_shuffle_bf16_instances = std::tuple<
|
||||
// clang-format off
|
||||
//################################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm| NumGroups|
|
||||
//################################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline| ToMerge|
|
||||
//################################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Sched| Ver| |
|
||||
//################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | | |
|
||||
DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 16, 16, 32, 8, 16, 16, 1, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 0, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 0, 1, 1, S<1, 4, 1, 8>, 1, Scheduler, PipelineVersion, 1>
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 4, Scheduler, PipelineVersion, 1>,
|
||||
// DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 4, Scheduler, PipelineVersion, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,100 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
#ifdef CK_ENABLE_FP8
|
||||
using F8 = ck::f8_t;
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF8
|
||||
using BF8 = ck::bf8_t;
|
||||
#endif
|
||||
|
||||
using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvBwdWeightDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec,
|
||||
BlockGemmPipelineScheduler Scheduler = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion PipelineVersion = BlockGemmPipelineVersion::v1>
|
||||
using device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_f16_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm|
|
||||
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline|
|
||||
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | |
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, Scheduler, PipelineVersion>
|
||||
// DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, Scheduler, PipelineVersion>, // Incorrect results for at least GemmDefault
|
||||
// DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, Scheduler, PipelineVersion> // Incorrect results for at least GemmDefault
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec,
|
||||
BlockGemmPipelineScheduler Scheduler = BlockGemmPipelineScheduler::Intrawave,
|
||||
BlockGemmPipelineVersion PipelineVersion = BlockGemmPipelineVersion::v1>
|
||||
using device_grouped_conv_bwd_weight_v3_wmma_c_shuffle_bf16_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#########################################| Num| InLayout| WeiLayout| OutLayout| InData| WeiData| OutData| AccData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CShuffleBlockTransfer| BlockGemm| BlockGemm|
|
||||
//#########################################| Dim| | | | Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline|
|
||||
//#########################################| Spatial| | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version|
|
||||
//#########################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | |
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, Scheduler, PipelineVersion>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, Scheduler, PipelineVersion>
|
||||
//clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,97 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Bilinear = ck::tensor_operation::element_wise::Bilinear;
|
||||
|
||||
static constexpr auto ConvBwdWeightDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_bilinear_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#################################################| Num| InLayout| WeiLayout| OutLayout| DsLayout| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm|
|
||||
//#################################################| Dim| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline |
|
||||
//#################################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version |
|
||||
//#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F16, F16, F16, F32, Tuple<F16>, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F16, F16, F16, F32, Tuple<F16>, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
// for fp16 conv.K and conv.C must be divisible by 2
|
||||
// since half_t atomic_add require scalar_per_x_vector % 2 == 0
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F16, F16, F16, F32, Tuple<F16>, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F16, F16, F16, F32, Tuple<F16>, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F16, F16, F16, F32, Tuple<F16>, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F16, F16, F16, F32, Tuple<F16>, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F16, F16, F16, F32, Tuple<F16>, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F16, F16, F16, F32, Tuple<F16>, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F16, F16, F16, F32, Tuple<F16>, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F16, F16, F16, F32, Tuple<F16>, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F16, F16, F16, F32, Tuple<F16>, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
|
||||
// DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F16, F16, F16, F32, Tuple<F16>, PassThrough, Bilinear, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, // Presumably doesn't produce correct results for f16
|
||||
// DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, F16, F16, F16, F32, Tuple<F16>, PassThrough, Bilinear, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> // Presumably doesn't produce correct results for f16
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_bilinear_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#################################################| Num| InLayout| WeiLayout| OutLayout| DsLayout| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm|
|
||||
//#################################################| Dim| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline |
|
||||
//#################################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version |
|
||||
//#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, BF16, F32, BF16, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
// other instances
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, BF16, F32, BF16, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, BF16, F32, BF16, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, BF16, F32, BF16, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, BF16, F32, BF16, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, BF16, F32, BF16, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, BF16, F32, BF16, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, BF16, F32, BF16, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, BF16, F32, BF16, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, BF16, F32, BF16, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
|
||||
// DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, BF16, F32, BF16, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, // Verification failure
|
||||
// DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Tuple<BLayout>, BF16, F32, BF16, F32, Tuple<F32>, PassThrough, Bilinear, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> // Verification failure
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,117 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_wmma_cshuffle.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
using I8 = int8_t;
|
||||
using I32 = int32_t;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvBwdWeightDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
template <index_t NDSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_weight_wmma_f16_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#####################################| NumDim| A| B| C| AData| BData| CData| AccData| A| B| C| ConvForward| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//#####################################| Spatial| Layout| Layout| Layout| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeatPerWave| NRepeatPerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
//#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 64, 4, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
|
||||
// blocksize=256
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 128, 256, 8, 8, 16, 16, 2, 8, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 128, 8, 8, 16, 16, 8, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 8>, 2>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 256, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 16>, 4>,
|
||||
// blocksize=128
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 8, 8, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 128, 8, 8, 16, 16, 4, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 256, 8, 8, 16, 16, 1, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 256, 32, 8, 8, 16, 16, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
// blocksize=64
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 32, 8, 8, 16, 16, 4, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 64, 8, 8, 16, 16, 1, 4, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 64, 8, 8, 16, 16, 2, 4, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 8, 8, 16, 16, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 8, 8, 16, 16, 1, 8, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
// blocksize=32
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 16, 32, 8, 8, 16, 16, 1, 2, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 16, 64, 8, 8, 16, 16, 1, 4, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 32, 64, 8, 8, 16, 16, 2, 4, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 32, 32, 8, 8, 16, 16, 2, 2, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 64, 32, 8, 8, 16, 16, 4, 2, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 64, 16, 8, 8, 16, 16, 4, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 32, 16, 8, 8, 16, 16, 2, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename CLayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_weight_wmma_i8_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#####################################| NumDim| A| B| C| AData| BData| CData| AccData| A| B| C| ConvForward| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//#####################################| Spatial| Layout| Layout| Layout| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeatPerWave| NRepeatPerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
//#####################################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
//#####################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 64, 4, 8, 16, 16, 2, 2, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
|
||||
// blocksize=256
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 64, 256, 8, 8, 16, 16, 2, 4, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 256, 256, 64, 8, 8, 16, 16, 4, 2, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 64, 1, 4>, 8>,
|
||||
// blocksize=128
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 256, 8, 8, 16, 16, 4, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 256, 8, 8, 16, 16, 2, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 32, 256, 8, 8, 16, 16, 1, 8, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 64, 128, 8, 8, 16, 16, 2, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 128, 64, 8, 8, 16, 16, 4, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 256, 32, 8, 8, 16, 16, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 2, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 256, 64, 8, 8, 16, 16, 8, 2, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 8>, 2>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 128, 256, 128, 8, 8, 16, 16, 8, 4, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 8>,
|
||||
// blocksize=64
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 32, 128, 8, 8, 16, 16, 1, 8, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 64, 128, 8, 8, 16, 16, 2, 8, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 32, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 64, 8, 8, 16, 16, 8, 2, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 64, 128, 32, 8, 8, 16, 16, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>,
|
||||
// blocksize=32
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 16, 64, 8, 8, 16, 16, 1, 4, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 64, 64, 8, 8, 16, 16, 4, 4, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 32, 32, 8, 8, 16, 16, 2, 2, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle<NDSpatial, ALayout, BLayout, CLayout, I8, I8, I8, I32, PassThrough, PassThrough, PassThrough, ConvSpec, 32, 64, 16, 8, 8, 16, 16, 4, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 16, 8, 1, S<8, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 2>, 8>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,96 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/ck.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_weight_multiple_d_wmma_cshuffle_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
using Empty_Tuple = ck::Tuple<>;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using Scale = ck::tensor_operation::element_wise::Scale;
|
||||
|
||||
static constexpr auto ConvBwdWeightDefault =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvBwdWeightFilter1x1Stride1Pad0 =
|
||||
ck::tensor_operation::device::ConvolutionBackwardWeightSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_weight_wmma_c_shuffle_f16_scale_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#################################################| Num| InLayout| WeiLayout| OutLayout| DsLayout| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm|
|
||||
//#################################################| Dim| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline |
|
||||
//#################################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version |
|
||||
//#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 4, 1, 1, 1, S<1, 16, 1, 4>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
// for fp16 conv.K and conv.C must be divisible by 2
|
||||
// since half_t atomic_add require scalar_per_x_vector % 2 == 0
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
|
||||
// DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, // Presumably doesn't produce correct results for fp16
|
||||
// DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, F16, F16, F16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> // Presumably doesn't produce correct results for fp16
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <ck::index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardWeightSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_weight_wmma_c_shuffle_bf16_scale_instances = std::tuple<
|
||||
// clang-format off
|
||||
//#################################################| Num| InLayout| WeiLayout| OutLayout| DsLayout| InData| WeiData| OutData| AccData| DsData| In| Wei| Out| ConvBackward| Block| MPer| NPer| KPer| ABK1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CShuffleBlockTransfer| CBlockTransfer| BlockGemm| BlockGemm|
|
||||
//#################################################| Dim| | | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Weight| Size| Block| Block| Block| | Wmma| Wmma| | | ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MRepeat| NRepeat| ClusterLengths| ScalarPerVector| Pipeline| Pipeline |
|
||||
//#################################################| Spatial| | | | | | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| MBlock_MPerBlock| _NPerBlock| Scheduler| Version |
|
||||
//#################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | NBlock_NPerBlock| | | |
|
||||
// generic instance
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 1, 4, 1, 1, 1, S<1, 16, 1, 4>, 1, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
// other instances
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 32, 32, 32, 8, 16, 16, 2, 1, S<4, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, S<4, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 2, 2, 0, 1, 1, S<1, 8, 1, 8>, 2, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 128, 128, 32, 8, 16, 16, 8, 2, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<4, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 64, 64, 64, 64, 8, 16, 16, 4, 2, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<8, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 4>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 128, 256, 64, 8, 16, 16, 8, 2, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 4, 8, 1, S<8, 32, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 48, 64, 128, 8, 16, 16, 3, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 96, 128, 64, 8, 16, 16, 6, 2, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<8, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 128, 64, 64, 128, 8, 16, 16, 4, 1, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, S<16, 8, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 256, 96, 128, 128, 8, 16, 16, 6, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<16, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 8, 8, 0, 1, 1, S<1, 16, 1, 16>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>
|
||||
// DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1>, // Verification failure
|
||||
// DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< NDimSpatial, ALayout, BLayout, ELayout, Empty_Tuple, BF16, F32, BF16, F32, Empty_Tuple, PassThrough, Scale, PassThrough, ConvSpec, 96, 96, 96, 48, 8, 16, 16, 6, 2, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 1, S<6, 16, 1>, S<2, 0, 1>, S<1, 0, 2>, 1, 6, 8, 0, 1, 1, S<1, 16, 1, 6>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v1> // Verification failure
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -21,6 +21,7 @@
|
||||
#endif
|
||||
#ifdef CK_USE_WMMA
|
||||
#include "grouped_convolution_backward_weight_wmma.inc"
|
||||
#include "grouped_convolution_backward_weight_explicit_wmma.inc"
|
||||
#endif
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
@@ -414,21 +415,24 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_f16_pipev5_irregular_instances(
|
||||
op_ptrs);
|
||||
// Explicit GEMM
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_mn_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_n_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_m_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
@@ -471,23 +475,23 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_xdl_nhwgc_gkyxc_nhwgk_bf16_pipev5_irregular_instances(
|
||||
op_ptrs);
|
||||
// Explicit GEMM
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
@@ -678,21 +682,24 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_f16_pipev5_irregular_instances(
|
||||
op_ptrs);
|
||||
// Explicit GEMM
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_mn_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_n_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_m_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
@@ -735,23 +742,23 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_pipev5_irregular_instances(
|
||||
op_ptrs);
|
||||
// Explicit GEMM
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
@@ -850,35 +857,53 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_USE_WMMA
|
||||
if constexpr(NumDimSpatial == 3)
|
||||
if constexpr(NumDimSpatial == 2)
|
||||
{
|
||||
if constexpr(is_same_v<InLayout, GNDHWC> && is_same_v<WeiLayout, GKZYXC> &&
|
||||
is_same_v<OutLayout, GNDHWK>)
|
||||
if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
|
||||
is_same_v<OutLayout, NHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
|
||||
is_same_v<ComputeTypeB, half_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances(
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_instances(
|
||||
op_ptrs);
|
||||
// Explicit GEMM
|
||||
add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t> &&
|
||||
is_same_v<ComputeTypeA, int8_t> &&
|
||||
is_same_v<ComputeTypeB, int8_t>)
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t> &&
|
||||
is_same_v<ComputeTypeA, ck::bhalf_t> &&
|
||||
is_same_v<ComputeTypeB, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(
|
||||
add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances(
|
||||
add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_instances(
|
||||
op_ptrs);
|
||||
// Explicit GEMM
|
||||
add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
if constexpr(NumDimSpatial == 3)
|
||||
{
|
||||
if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
|
||||
is_same_v<OutLayout, NDHWGK>)
|
||||
{
|
||||
@@ -889,26 +914,40 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances(
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instances(
|
||||
op_ptrs);
|
||||
// Explicit GEMM
|
||||
add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
else if constexpr(is_same_v<InDataType, int8_t> && is_same_v<WeiDataType, int8_t> &&
|
||||
is_same_v<OutDataType, int8_t> &&
|
||||
is_same_v<ComputeTypeA, int8_t> &&
|
||||
is_same_v<ComputeTypeB, int8_t>)
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
|
||||
is_same_v<WeiDataType, ck::bhalf_t> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t> &&
|
||||
is_same_v<ComputeTypeA, ck::bhalf_t> &&
|
||||
is_same_v<ComputeTypeB, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances(
|
||||
add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instances(
|
||||
op_ptrs);
|
||||
// Explicit GEMM
|
||||
add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return op_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -17,6 +17,39 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#ifdef CK_USE_WMMA
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
Tuple<GKZYXC>,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16>,
|
||||
PassThrough,
|
||||
Bilinear,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
Tuple<GKZYXC>,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
Tuple<F32>,
|
||||
PassThrough,
|
||||
Bilinear,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
|
||||
@@ -148,6 +181,35 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
|
||||
#ifdef CK_USE_WMMA
|
||||
if constexpr(NumDimSpatial == 3)
|
||||
{
|
||||
if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
|
||||
is_same_v<OutLayout, NDHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
|
||||
is_same_v<ComputeTypeB, half_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t> &&
|
||||
is_same_v<ComputeTypeA, ck::bhalf_t> &&
|
||||
is_same_v<ComputeTypeB, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_bilinear_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
if constexpr(NumDimSpatial == 3)
|
||||
{
|
||||
|
||||
@@ -0,0 +1,171 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// 2D
|
||||
#ifdef CK_ENABLE_BF16
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
// 3D
|
||||
#ifdef CK_ENABLE_BF16
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_wmma_f16_f16_f16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -10,7 +10,7 @@ namespace instance {
|
||||
// 2D
|
||||
#ifdef CK_ENABLE_BF16
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -22,7 +22,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instan
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -34,7 +34,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_ins
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -46,7 +46,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_inst
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -58,7 +58,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_i
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -70,7 +70,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_inst
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -82,7 +82,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_i
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -94,7 +94,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -106,7 +106,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -121,7 +121,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -133,7 +133,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -145,7 +145,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instan
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -157,7 +157,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instanc
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -169,7 +169,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_inst
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -181,7 +181,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instanc
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -193,7 +193,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_inst
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -205,7 +205,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_m_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -217,7 +217,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_n_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
@@ -232,7 +232,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(
|
||||
// 3D
|
||||
#ifdef CK_ENABLE_BF16
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -244,7 +244,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_default_instan
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -256,7 +256,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_comp_mnkpadding_ins
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -268,7 +268,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_default_inst
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -280,7 +280,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v1_mnkpadding_i
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -292,7 +292,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_default_inst
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -304,7 +304,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_mem_v2_mnkpadding_i
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -316,7 +316,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_mn_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -328,7 +328,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_m_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -343,7 +343,7 @@ void add_device_grouped_convnd_bwd_weight_bf16_bf16_bf16_exp_odd_n_instances(
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -355,7 +355,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_default_instances
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_comp_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -367,7 +367,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_comp_mnkpadding_instan
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -379,7 +379,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_default_instanc
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v1_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -391,7 +391,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v1_mnkpadding_inst
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_default_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -403,7 +403,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_default_instanc
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_mem_v2_mnkpadding_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -415,7 +415,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_mem_v2_mnkpadding_inst
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_mn_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -427,7 +427,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_mn_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_m_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -439,7 +439,7 @@ void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_m_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_convnd_bwd_weight_f16_f16_f16_exp_odd_n_instances(
|
||||
void add_device_grouped_convnd_bwd_weight_xdl_f16_f16_f16_exp_odd_n_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
|
||||
@@ -17,6 +17,40 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
#ifdef CK_USE_WMMA
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
Tuple<>,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<>,
|
||||
PassThrough,
|
||||
Scale,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
F32,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
PassThrough,
|
||||
Scale,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
|
||||
@@ -147,6 +181,34 @@ struct DeviceOperationInstanceFactory<
|
||||
static auto GetInstances()
|
||||
{
|
||||
std::vector<std::unique_ptr<DeviceOp>> op_ptrs;
|
||||
#ifdef CK_USE_WMMA
|
||||
if constexpr(NumDimSpatial == 3)
|
||||
{
|
||||
if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
|
||||
is_same_v<OutLayout, NDHWGK>)
|
||||
{
|
||||
#ifdef CK_ENABLE_FP16
|
||||
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
|
||||
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
|
||||
is_same_v<ComputeTypeB, half_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
if constexpr(is_same_v<InDataType, ck::bhalf_t> && is_same_v<WeiDataType, float> &&
|
||||
is_same_v<OutDataType, ck::bhalf_t> &&
|
||||
is_same_v<ComputeTypeA, ck::bhalf_t> &&
|
||||
is_same_v<ComputeTypeB, ck::bhalf_t>)
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_weight_wmma_scale_ndhwgc_gkzyxc_ndhwgk_bf16_f32_bf16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifdef CK_USE_XDL
|
||||
if constexpr(NumDimSpatial == 3)
|
||||
|
||||
@@ -8,32 +8,61 @@ namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
// conv2d backward weight
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_f16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_bwd_weight_wmma_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_weight_two_stage_wmma_nhwgc_gkyxc_nhwgk_bf16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
// conv3d backward weight
|
||||
#ifdef CK_ENABLE_FP16
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_f16_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
F16,
|
||||
F16,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
@@ -46,7 +75,7 @@ void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_instances(
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_f16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
@@ -58,51 +87,28 @@ void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_f16_1x1s1p0_
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_INT8
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_gndhwc_gkzyxc_gndhwk_i8_1x1s1p0_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
GNDHWC,
|
||||
GKZYXC,
|
||||
GNDHWK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_instances(
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_weight_wmma_ndhwgc_gkzyxc_ndhwgk_i8_1x1s1p0_instances(
|
||||
void add_device_grouped_conv3d_bwd_weight_two_stage_wmma_ndhwgc_gkzyxc_ndhwgk_bf16_pipev1_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeight<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
NDHWGK,
|
||||
int8_t,
|
||||
int8_t,
|
||||
int8_t,
|
||||
BF16,
|
||||
BF16,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
Reference in New Issue
Block a user