mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Integrate universal gemm with conv bwd data and add SplitK (#1315)
* Integrate universal gemm with conv bwd data * Fix multi d kernel * Add splitK support * instances refactor * instances refactor * refactor * fixeS * fixes * 16x16 instnaces * Fixes * Fix * Fix * Fix * Fix * Fix * Fixes * fix * fix
This commit is contained in:
@@ -15,6 +15,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
* Added support for Stream-K version of mixed fp8/bf16 GEMM
|
||||
* Added GEMM pipeline for microscaling (MX) data types
|
||||
* Added support for FP16 2:4 structured sparsity to universal GEMM.
|
||||
* Added support for Split K for grouped convolution backward data.
|
||||
|
||||
### Optimized
|
||||
|
||||
|
||||
4
Jenkinsfile
vendored
4
Jenkinsfile
vendored
@@ -937,8 +937,8 @@ pipeline {
|
||||
environment{
|
||||
setup_args = "NO_CK_BUILD"
|
||||
execute_args = """ ../script/cmake-ck-dev.sh ../ gfx90a && \
|
||||
make -j64 test_grouped_convnd_fwd_large_cases_xdl && \
|
||||
./bin/test_grouped_convnd_fwd_large_cases_xdl"""
|
||||
make -j64 test_grouped_convnd_fwd_large_cases_xdl test_grouped_convnd_bwd_data_xdl_large_cases && \
|
||||
./bin/test_grouped_convnd_fwd_large_cases_xdl && ./bin/test_grouped_convnd_bwd_data_xdl_large_cases"""
|
||||
}
|
||||
steps{
|
||||
buildHipClangJobAndReboot(setup_args:setup_args, no_reboot:true, build_type: 'Release', execute_cmd: execute_args)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -59,7 +59,8 @@ struct DeviceGroupedConvBwdDataMultipleD : public BaseOperator
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOperation& a_element_op,
|
||||
const BElementwiseOperation& b_element_op,
|
||||
const CDEElementwiseOperation& cde_element_op) = 0;
|
||||
const CDEElementwiseOperation& cde_element_op,
|
||||
const ck::index_t split_k = 1) = 0;
|
||||
|
||||
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
|
||||
};
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -227,7 +227,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOp& a_element_op,
|
||||
const BElementwiseOp& b_element_op,
|
||||
const CDEElementwiseOp& cde_element_op)
|
||||
const CDEElementwiseOp& cde_element_op,
|
||||
const ck::index_t split_k = 1)
|
||||
: p_a_grid_{static_cast<const ADataType*>(p_a)},
|
||||
p_b_grid_{static_cast<const BDataType*>(p_b)},
|
||||
p_ds_grid_{},
|
||||
@@ -240,7 +241,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
b_g_k_c_xs_lengths_{b_g_k_c_xs_lengths},
|
||||
conv_filter_strides_{conv_filter_strides},
|
||||
input_left_pads_{input_left_pads},
|
||||
input_right_pads_{input_right_pads}
|
||||
input_right_pads_{input_right_pads},
|
||||
k_batch_{split_k}
|
||||
{
|
||||
// populate Ds pointer
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
@@ -445,6 +447,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
std::array<index_t, NDimSpatial> conv_filter_strides_;
|
||||
std::array<index_t, NDimSpatial> input_left_pads_;
|
||||
std::array<index_t, NDimSpatial> input_right_pads_;
|
||||
|
||||
const index_t k_batch_;
|
||||
};
|
||||
|
||||
// Invoker
|
||||
@@ -534,6 +538,11 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
|
||||
static bool IsSupportedArgument(const Argument& arg)
|
||||
{
|
||||
if(arg.k_batch_ != 1)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// check device
|
||||
if(ck::is_gfx11_supported() || ck::is_gfx12_supported())
|
||||
{
|
||||
@@ -691,7 +700,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOp& a_element_op,
|
||||
const BElementwiseOp& b_element_op,
|
||||
const CDEElementwiseOp& cde_element_op)
|
||||
const CDEElementwiseOp& cde_element_op,
|
||||
const ck::index_t split_k = 1)
|
||||
{
|
||||
return Argument{p_a,
|
||||
p_b,
|
||||
@@ -711,7 +721,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
input_right_pads,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op};
|
||||
cde_element_op,
|
||||
split_k};
|
||||
}
|
||||
|
||||
static auto MakeInvoker() { return Invoker{}; }
|
||||
@@ -737,7 +748,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
const std::array<index_t, NDimSpatial>& input_right_pads,
|
||||
const AElementwiseOp& a_element_op,
|
||||
const BElementwiseOp& b_element_op,
|
||||
const CDEElementwiseOp& cde_element_op) override
|
||||
const CDEElementwiseOp& cde_element_op,
|
||||
const ck::index_t split_k = 1) override
|
||||
{
|
||||
return std::make_unique<Argument>(p_a,
|
||||
p_b,
|
||||
@@ -757,7 +769,8 @@ struct DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
input_right_pads,
|
||||
a_element_op,
|
||||
b_element_op,
|
||||
cde_element_op);
|
||||
cde_element_op,
|
||||
split_k);
|
||||
}
|
||||
|
||||
std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -19,7 +19,7 @@
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp"
|
||||
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
#include "ck/tensor_operation/operator_transform/transform_conv_bwd_weight_to_gemm_v2.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_bwd_weight_v3.hpp"
|
||||
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_conv_v3.hpp"
|
||||
#include <ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp>
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
|
||||
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -187,7 +187,8 @@ struct TransformConvBwdDataToGemm_v1
|
||||
WTilde_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.WTilde_)},
|
||||
ZDot_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.ZDot_)},
|
||||
YDot_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.YDot_)},
|
||||
XDot_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.XDot_)}
|
||||
XDot_{static_cast<IndexType>(transform_conv_bwd_data_to_gemm_base.XDot_)},
|
||||
batch_k_{transform_conv_bwd_data_to_gemm_base.batch_k_}
|
||||
{
|
||||
}
|
||||
|
||||
@@ -203,7 +204,8 @@ struct TransformConvBwdDataToGemm_v1
|
||||
const ConvSpatialDimsType& conv_filter_dilations,
|
||||
const ConvSpatialDimsType& input_left_pads,
|
||||
const ConvSpatialDimsType& input_right_pads,
|
||||
const ConvSpatialDimsType& tildes)
|
||||
const ConvSpatialDimsType& tildes,
|
||||
const index_t batch_k = 1)
|
||||
: Hi_{c_g_n_c_wis_lengths[HIdx]},
|
||||
Wi_{c_g_n_c_wis_lengths[WIdx]},
|
||||
Ho_{a_g_n_k_wos_lengths[HIdx]},
|
||||
@@ -231,7 +233,8 @@ struct TransformConvBwdDataToGemm_v1
|
||||
InRightPadH_{input_right_pads[HIdx - NonSpatialDimsNum]},
|
||||
InRightPadW_{input_right_pads[WIdx - NonSpatialDimsNum]},
|
||||
IdxYTilde_{tildes[YIdx - NonSpatialDimsNum]},
|
||||
IdxXTilde_{tildes[XIdx - NonSpatialDimsNum]}
|
||||
IdxXTilde_{tildes[XIdx - NonSpatialDimsNum]},
|
||||
batch_k_{batch_k}
|
||||
{
|
||||
static_assert(is_same_v<ConvSpatialDimsType, std::array<IndexType, NDimSpatial>> ||
|
||||
is_same_v<ConvSpatialDimsType, ck::Array<IndexType, NDimSpatial>>);
|
||||
@@ -616,20 +619,22 @@ struct TransformConvBwdDataToGemm_v1
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
|
||||
Filter1x1Stride1Pad0)
|
||||
{
|
||||
const index_t AK0 = math::integer_divide_ceil(K_, AK1);
|
||||
const index_t K0PerBlock = GemmKPerBlock / AK1;
|
||||
const index_t AK0 =
|
||||
math::integer_divide_ceil(K_, AK1 * K0PerBlock * batch_k_) * K0PerBlock;
|
||||
|
||||
// A: output tensor
|
||||
const auto out_gemmak0_gemmmraw_gemmak1_grid_desc = transform_tensor_descriptor(
|
||||
out_grid_desc,
|
||||
make_tuple(make_pass_through_transform(N_ * Do_ * Ho_ * Wo_),
|
||||
make_unmerge_transform(make_tuple(AK0, AK1))),
|
||||
make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<1>{}, Sequence<0, 2>{}));
|
||||
|
||||
const auto out_gemmak0_gemmm_gemmak1_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
out_gemmak0_gemmmraw_gemmak1_grid_desc,
|
||||
make_tuple(AK0, GemmMPerBlock, AK1),
|
||||
make_tuple(AK0 * batch_k_, GemmMPerBlock, AK1),
|
||||
Sequence<false, DoPadGemmM, false>{});
|
||||
|
||||
return out_gemmak0_gemmm_gemmak1_grid_desc;
|
||||
@@ -719,11 +724,15 @@ struct TransformConvBwdDataToGemm_v1
|
||||
make_tuple(GemmKPerBlock, GemmMPerBlock),
|
||||
Sequence<true, DoPadGemmM>{});
|
||||
|
||||
const index_t AK0 = out_gemmk_gemmm_padded_grid_desc.GetLength(I0) / AK1;
|
||||
const index_t K0PerBlock = GemmKPerBlock / AK1;
|
||||
const index_t AK0 =
|
||||
math::integer_divide_ceil(out_gemmk_gemmm_padded_grid_desc.GetLength(I0),
|
||||
AK1 * K0PerBlock * batch_k_) *
|
||||
K0PerBlock;
|
||||
|
||||
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmk_gemmm_padded_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)),
|
||||
make_pass_through_transform(
|
||||
out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
@@ -816,11 +825,15 @@ struct TransformConvBwdDataToGemm_v1
|
||||
make_tuple(GemmKPerBlock, GemmMPerBlock),
|
||||
Sequence<true, DoPadGemmM>{});
|
||||
|
||||
const index_t AK0 = out_gemmk_gemmm_padded_grid_desc.GetLength(I0) / AK1;
|
||||
const index_t K0PerBlock = GemmKPerBlock / AK1;
|
||||
const index_t AK0 =
|
||||
math::integer_divide_ceil(out_gemmk_gemmm_padded_grid_desc.GetLength(I0),
|
||||
AK1 * K0PerBlock * batch_k_) *
|
||||
K0PerBlock;
|
||||
|
||||
const auto out_gemmak0_gemmm_gemmak1_grid_desc = transform_tensor_descriptor(
|
||||
out_gemmk_gemmm_padded_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(AK0 * batch_k_, AK1)),
|
||||
make_pass_through_transform(
|
||||
out_gemmk_gemmm_padded_grid_desc.GetLength(I1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
@@ -850,21 +863,23 @@ struct TransformConvBwdDataToGemm_v1
|
||||
ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::
|
||||
Filter1x1Stride1Pad0)
|
||||
{
|
||||
const index_t BK0 = math::integer_divide_ceil(K_, BK1);
|
||||
const index_t K0PerBlock = GemmKPerBlock / BK1;
|
||||
const index_t BK0 =
|
||||
math::integer_divide_ceil(K_, BK1 * K0PerBlock * batch_k_) * K0PerBlock;
|
||||
|
||||
// B: weight tensor
|
||||
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc =
|
||||
transform_tensor_descriptor(make_naive_tensor_descriptor_packed(make_tuple(K_, C_)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
const auto wei_gemmbk0_gemmnraw_gemmbk1_grid_desc = transform_tensor_descriptor(
|
||||
make_naive_tensor_descriptor_packed(make_tuple(K_, C_)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)),
|
||||
make_pass_through_transform(C_)),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
|
||||
make_naive_tensor_descriptor(make_tuple(N_ * Do_ * Ho_ * Wo_, C_), make_tuple(I0, I1));
|
||||
|
||||
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc =
|
||||
ck::tensor_operation::device::PadTensorDescriptor(
|
||||
wei_gemmbk0_gemmnraw_gemmbk1_grid_desc,
|
||||
make_tuple(BK0, GemmNPerBlock, BK1),
|
||||
make_tuple(BK0 * batch_k_, GemmNPerBlock, BK1),
|
||||
Sequence<false, DoPadGemmN, false>{});
|
||||
|
||||
return wei_gemmbk0_gemmn_gemmbk1_grid_desc;
|
||||
@@ -925,11 +940,15 @@ struct TransformConvBwdDataToGemm_v1
|
||||
make_tuple(GemmKPerBlock, GemmNPerBlock),
|
||||
Sequence<true, DoPadGemmN>{});
|
||||
|
||||
const index_t BK0 = wei_gemmk_gemmn_padded_grid_desc.GetLength(I0) / BK1;
|
||||
const index_t K0PerBlock = GemmKPerBlock / BK1;
|
||||
const index_t BK0 =
|
||||
math::integer_divide_ceil(wei_gemmk_gemmn_padded_grid_desc.GetLength(I0),
|
||||
BK1 * K0PerBlock * batch_k_) *
|
||||
K0PerBlock;
|
||||
|
||||
const auto wei_gemmbk0_gemmn_gemmbk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmk_gemmn_padded_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)),
|
||||
make_pass_through_transform(
|
||||
wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
@@ -1006,11 +1025,15 @@ struct TransformConvBwdDataToGemm_v1
|
||||
make_tuple(GemmKPerBlock, GemmNPerBlock),
|
||||
Sequence<true, DoPadGemmN>{});
|
||||
|
||||
const index_t BK0 = wei_gemmk_gemmn_padded_grid_desc.GetLength(I0) / BK1;
|
||||
const index_t K0PerBlock = GemmKPerBlock / BK1;
|
||||
const index_t BK0 =
|
||||
math::integer_divide_ceil(wei_gemmk_gemmn_padded_grid_desc.GetLength(I0),
|
||||
BK1 * K0PerBlock * batch_k_) *
|
||||
K0PerBlock;
|
||||
|
||||
const auto wei_gemmbk0_gemm_gemmbk1_grid_desc = transform_tensor_descriptor(
|
||||
wei_gemmk_gemmn_padded_grid_desc,
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
|
||||
make_tuple(make_unmerge_transform(make_tuple(BK0 * batch_k_, BK1)),
|
||||
make_pass_through_transform(
|
||||
wei_gemmk_gemmn_padded_grid_desc.GetLength(I1))),
|
||||
make_tuple(Sequence<0>{}, Sequence<1>{}),
|
||||
@@ -1355,6 +1378,7 @@ struct TransformConvBwdDataToGemm_v1
|
||||
IndexType ZTilde_, YTilde_, XTilde_;
|
||||
IndexType DTilde_, HTilde_, WTilde_;
|
||||
IndexType ZDot_, YDot_, XDot_;
|
||||
index_t batch_k_;
|
||||
};
|
||||
|
||||
} // namespace tensor_operation
|
||||
|
||||
@@ -54,6 +54,28 @@ using device_grouped_conv_bwd_data_xdl_f16_generic_instances =
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardDataSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_data_xdl_f16_16_16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
|
||||
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
@@ -73,7 +95,7 @@ using device_grouped_conv_bwd_data_xdl_f16_instances =
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
@@ -108,6 +130,27 @@ using device_grouped_conv_bwd_data_xdl_bf16_generic_instances = std::tuple<
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardDataSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_data_xdl_bf16_16_16_instances = std::tuple<
|
||||
// clang-format off
|
||||
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
|
||||
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 2, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
@@ -126,7 +169,7 @@ using device_grouped_conv_bwd_data_xdl_bf16_instances = std::tuple<
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<NDimSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, Empty_Tuple, BF16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 8>, 8>,
|
||||
@@ -162,6 +205,28 @@ using device_grouped_conv_bwd_data_xdl_f32_generic_instances =
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionBackwardDataSpecialization ConvSpec>
|
||||
using device_grouped_conv_bwd_data_xdl_f32_16_16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
// ##############################################| NDim| ALayout| BLayout| DsLayout| ELayout| AData| BData| AccData| CShuffle| DsData| EData| AElementwise| BElementwise| CDEElementwise| ConvolutionBackward| DoPad| DoPad| NumGemmK| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffleMXdl| CShuffleNXdl| CDEBlockTransfer| CDEBlockTransfer|
|
||||
// ##############################################| Spatial| | | | | Type| Type| Type| DataType| Type| Type| Operation| Operation| Operation| DataSpecialization| GemmM| GemmN| PrefetchStage| Size| Block| Block| Block| | | XDL| XDL| PerWave| PerWave| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| ExtraN| PerWave| PerWave| _MBlock_MPerBlock| ScalarPerVector|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | Lengths_AK0_M_AK1| ArrangeOrder| | | PerVector| PerVector_AK1| | Lengths_BK0_N_BK1| ArrangeOrder| | | PerVector| PerVector_BK1| | PerShuffle| PerShuffle| _NBlock_NPerBlock| _NPerBlock|
|
||||
// ##############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 16, 64, 32, 8, 8, 16, 16, 1, 4, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 16, 32, 8, 8, 16, 16, 4, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 4, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDimSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
@@ -181,7 +246,7 @@ using device_grouped_conv_bwd_data_xdl_f32_instances =
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 0, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>,
|
||||
@@ -194,7 +259,7 @@ using device_grouped_conv_bwd_data_xdl_f32_instances =
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 32, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 64, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 4>, 4>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 1, 1, 1, S<1, 16, 1, 4>, 8>
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F32, F32, F32, F32, Empty_Tuple, F32, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 64, 32, 64, 32, 8, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 8, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 1, 1, 1, S<1, 16, 1, 4>, 4>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
@@ -218,7 +283,7 @@ using device_grouped_conv_bwd_data_xdl_input_fp16_comp_bf8f8_instances =
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 64, 128, 32, 8, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 32, 32, 8, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 8, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 1, 4, 1, 1, 1, S<1, 32, 1, 4>, 1, LoopScheduler::Default, BF8, F8>,
|
||||
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 2, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, 0, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 256, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 256, 128, 256, 32, 8, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 64, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 32, 1, 8>, 4, LoopScheduler::Default, BF8, F8>,
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1< NDimSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F32, Empty_Tuple, F16, PassThrough, PassThrough, PassThrough, ConvSpec, true, true, 1, 128, 128, 128, 32, 8, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, 1, S<4, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 4, 1, 1, 1, S<1, 16, 1, 8>, 4, LoopScheduler::Default, BF8, F8>,
|
||||
|
||||
@@ -109,6 +109,8 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_16_16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
@@ -117,6 +119,8 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<ComputeTypeB, F32>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_16_16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
@@ -126,6 +130,8 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_16_16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -167,6 +173,8 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<ComputeTypeB, F16>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_vec_transpose_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
@@ -177,6 +185,8 @@ struct DeviceOperationInstanceFactory<
|
||||
is_same_v<ComputeTypeB, F32>)
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_instances(op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_vec_transpose_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
@@ -188,6 +198,8 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_vec_transpose_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
@@ -237,6 +249,8 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_16_16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
@@ -255,6 +269,8 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_16_16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
@@ -264,6 +280,8 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -308,6 +326,8 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_vec_transpose_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
@@ -319,6 +339,8 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_vec_transpose_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
@@ -330,6 +352,8 @@ struct DeviceOperationInstanceFactory<
|
||||
{
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_16_16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_vec_transpose_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
|
||||
@@ -69,6 +69,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_16_16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
|
||||
@@ -84,6 +98,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_16_16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
|
||||
@@ -99,6 +127,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_16_16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
|
||||
#ifdef CK_ENABLE_FP16
|
||||
@@ -162,6 +204,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_16_16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NGKHW,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGCHW,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_vec_transpose_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NGKHW,
|
||||
@@ -191,6 +247,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_16_16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NGKHW,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGCHW,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_vec_transpose_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NGKHW,
|
||||
@@ -220,6 +290,20 @@ void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_16_16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NGKHW,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGCHW,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_vec_transpose_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NGKHW,
|
||||
@@ -295,6 +379,20 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_16_16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP32
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
|
||||
@@ -310,6 +408,20 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_16_16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#ifdef CK_ENABLE_BF16
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
|
||||
@@ -325,6 +437,20 @@ void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
#endif
|
||||
#if defined CK_ENABLE_FP16 && defined CK_ENABLE_FP8 && defined CK_ENABLE_BF8
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances(
|
||||
@@ -403,6 +529,20 @@ void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_16_16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NGKDHW,
|
||||
GKCZYX,
|
||||
Empty_Tuple,
|
||||
NGCDHW,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_vec_transpose_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NGKDHW,
|
||||
@@ -432,6 +572,20 @@ void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_16_16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NGKDHW,
|
||||
GKCZYX,
|
||||
Empty_Tuple,
|
||||
NGCDHW,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_vec_transpose_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NGKDHW,
|
||||
@@ -461,6 +615,20 @@ void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_instances(
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_16_16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NGKDHW,
|
||||
GKCZYX,
|
||||
Empty_Tuple,
|
||||
NGCDHW,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_vec_transpose_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NGKDHW,
|
||||
|
||||
@@ -7,9 +7,15 @@ add_instance_library(
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f16_16_16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_bf16_16_16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_nhwgc_gkyxc_nhwgk_f32_16_16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_bf16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkyxc_ngkhw_f32_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_16_16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_16_16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_16_16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_bf16_instance.cpp
|
||||
xdl/device_grouped_conv2d_bwd_data_xdl_ngchw_gkcyx_ngkhw_f32_instance.cpp
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
@@ -8,7 +8,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k]
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
@@ -26,21 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_bf16_instances(
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_bf16_instances<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
ConvBwdDataDefault>{});
|
||||
device_grouped_conv_bwd_data_xdl_bf16_16_16_instances<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
ConvBwdDataDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_bf16_instances<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
device_grouped_conv_bwd_data_xdl_bf16_16_16_instances<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
@@ -8,7 +8,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k]
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
@@ -26,21 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f16_instances(
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f16_instances<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
ConvBwdDataDefault>{});
|
||||
device_grouped_conv_bwd_data_xdl_f16_16_16_instances<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
ConvBwdDataDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f16_instances<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
device_grouped_conv_bwd_data_xdl_f16_16_16_instances<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
@@ -8,7 +8,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[g, n, hi, wi, c] * wei[g, k, y, x, c] = in[g, n, ho, wo, k]
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
GNHWK,
|
||||
@@ -26,21 +26,21 @@ void add_device_grouped_conv2d_bwd_data_xdl_gnhwk_gkyxc_gnhwc_f32_instances(
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f32_instances<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
ConvBwdDataDefault>{});
|
||||
device_grouped_conv_bwd_data_xdl_f32_16_16_instances<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
ConvBwdDataDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f32_instances<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
device_grouped_conv_bwd_data_xdl_f32_16_16_instances<2,
|
||||
GNHWK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
GNHWC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_16_16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NGKHW,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGCHW,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_bf16_16_16_instances<2,
|
||||
NGKHW,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGCHW,
|
||||
ConvBwdDataDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -9,7 +9,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k]
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NGKHW,
|
||||
|
||||
@@ -9,7 +9,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k]
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_bf16_vec_transpose_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NGKHW,
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_16_16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NGKHW,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGCHW,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f16_16_16_instances<2,
|
||||
NGKHW,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGCHW,
|
||||
ConvBwdDataDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -9,7 +9,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k]
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NGKHW,
|
||||
|
||||
@@ -9,7 +9,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k]
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f16_vec_transpose_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NGKHW,
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_16_16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NGKHW,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGCHW,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f32_16_16_instances<2,
|
||||
NGKHW,
|
||||
GKCYX,
|
||||
Empty_Tuple,
|
||||
NGCHW,
|
||||
ConvBwdDataDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -9,7 +9,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k]
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NGKHW,
|
||||
|
||||
@@ -9,7 +9,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k]
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkcyx_ngchw_f32_vec_transpose_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NGKHW,
|
||||
|
||||
@@ -9,7 +9,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k]
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NGKHW,
|
||||
|
||||
@@ -9,7 +9,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k]
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NGKHW,
|
||||
|
||||
@@ -9,7 +9,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k]
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_ngkhw_gkyxc_ngchw_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NGKHW,
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_16_16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_bf16_16_16_instances<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
ConvBwdDataDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_bf16_16_16_instances<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -8,7 +8,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k]
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_16_16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f16_16_16_instances<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
ConvBwdDataDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f16_16_16_instances<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -8,7 +8,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k]
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_16_16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f32_16_16_instances<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
ConvBwdDataDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f32_16_16_instances<2,
|
||||
NHWGK,
|
||||
GKYXC,
|
||||
Empty_Tuple,
|
||||
NHWGC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -8,7 +8,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, ho, wo, g, c] * wei[g, k, y, x, c] = in[n, hi, wi, g, k]
|
||||
|
||||
void add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<2,
|
||||
NHWGK,
|
||||
|
||||
@@ -6,15 +6,22 @@ set(GROUPED_CONV3D_BWD_DATA
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f16_16_16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_bf16_16_16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16_16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkzyxc_ngkdhw_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_16_16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_bf16_16_16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_16_16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_bf16_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f16_vec_transpose_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_bf16_vec_transpose_instance.cpp
|
||||
xdl/device_grouped_conv3d_bwd_data_xdl_ngcdhw_gkczyx_ngkdhw_f32_vec_transpose_instance.cpp
|
||||
|
||||
wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_f16_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_data_wmma_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
wmma/device_grouped_conv3d_bwd_data_wmma_gndhwc_gkzyxc_gndhwk_i8_instance.cpp
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
@@ -8,7 +8,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = in[g, n, do, ho,
|
||||
|
||||
// wo, k]
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
@@ -27,21 +27,21 @@ void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_bf16_instances(
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_bf16_instances<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
ConvBwdDataDefault>{});
|
||||
device_grouped_conv_bwd_data_xdl_bf16_16_16_instances<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
ConvBwdDataDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_bf16_instances<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
device_grouped_conv_bwd_data_xdl_bf16_16_16_instances<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
@@ -8,7 +8,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = in[g, n, do, ho,
|
||||
|
||||
// wo, k]
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
@@ -27,21 +27,21 @@ void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f16_instances(
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f16_instances<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
ConvBwdDataDefault>{});
|
||||
device_grouped_conv_bwd_data_xdl_f16_16_16_instances<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
ConvBwdDataDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f16_instances<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
device_grouped_conv_bwd_data_xdl_f16_16_16_instances<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
@@ -8,7 +8,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[g, n, di, hi, wi, c] * wei[g, k, z, y, x, c] = in[g, n, do, ho,
|
||||
|
||||
// wo, k]
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
@@ -27,21 +27,21 @@ void add_device_grouped_conv3d_bwd_data_xdl_gndhwk_gkzyxc_gndhwc_f32_instances(
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f32_instances<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
ConvBwdDataDefault>{});
|
||||
device_grouped_conv_bwd_data_xdl_f32_16_16_instances<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
ConvBwdDataDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f32_instances<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
device_grouped_conv_bwd_data_xdl_f32_16_16_instances<3,
|
||||
GNDHWK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
GNDHWC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_16_16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_bf16_16_16_instances<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
ConvBwdDataDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_bf16_16_16_instances<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
@@ -8,8 +8,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
|
||||
// g, k]
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_16_16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f16_16_16_instances<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
ConvBwdDataDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f16_16_16_instances<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
@@ -8,8 +8,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
|
||||
// g, k]
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
|
||||
@@ -0,0 +1,49 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_16_16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
// 1. Default
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f32_16_16_instances<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
ConvBwdDataDefault>{});
|
||||
// 2. Filter1x1Stride1Pad0
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f32_16_16_instances<3,
|
||||
NDHWGK,
|
||||
GKZYXC,
|
||||
Empty_Tuple,
|
||||
NDHWGC,
|
||||
ConvBwdDataFilter1x1Stride1Pad0>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
@@ -8,8 +8,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
|
||||
// g, k]
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
@@ -8,8 +8,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
|
||||
// g, k]
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_input_f16_comp_bf8f8_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_16_16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NGKDHW,
|
||||
GKCZYX,
|
||||
Empty_Tuple,
|
||||
NGCDHW,
|
||||
BF16,
|
||||
BF16,
|
||||
Empty_Tuple,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_bf16_16_16_instances<3,
|
||||
NGKDHW,
|
||||
GKCZYX,
|
||||
Empty_Tuple,
|
||||
NGCDHW,
|
||||
ConvBwdDataDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -9,8 +9,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
|
||||
// g, k]
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NGKDHW,
|
||||
|
||||
@@ -9,8 +9,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
|
||||
// g, k]
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_bf16_vec_transpose_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NGKDHW,
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_16_16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NGKDHW,
|
||||
GKCZYX,
|
||||
Empty_Tuple,
|
||||
NGCDHW,
|
||||
F16,
|
||||
F16,
|
||||
Empty_Tuple,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f16_16_16_instances<3,
|
||||
NGKDHW,
|
||||
GKCZYX,
|
||||
Empty_Tuple,
|
||||
NGCDHW,
|
||||
ConvBwdDataDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -9,8 +9,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
|
||||
// g, k]
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NGKDHW,
|
||||
|
||||
@@ -9,8 +9,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
|
||||
// g, k]
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f16_vec_transpose_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NGKDHW,
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_transpose_xdl_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_16_16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NGKDHW,
|
||||
GKCZYX,
|
||||
Empty_Tuple,
|
||||
NGCDHW,
|
||||
F32,
|
||||
F32,
|
||||
Empty_Tuple,
|
||||
F32,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
PassThrough>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_bwd_data_xdl_f32_16_16_instances<3,
|
||||
NGKDHW,
|
||||
GKCZYX,
|
||||
Empty_Tuple,
|
||||
NGCDHW,
|
||||
ConvBwdDataDefault>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -9,8 +9,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
|
||||
// g, k]
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NGKDHW,
|
||||
|
||||
@@ -9,8 +9,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
|
||||
// g, k]
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkczyx_ngcdhw_f32_vec_transpose_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NGKDHW,
|
||||
|
||||
@@ -9,8 +9,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
|
||||
// g, k]
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NGKDHW,
|
||||
|
||||
@@ -9,8 +9,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
|
||||
// g, k]
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NGKDHW,
|
||||
|
||||
@@ -9,8 +9,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
|
||||
// g, k]
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_ngkdhw_gkzyxc_ngcdhw_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NGKDHW,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_bilinear_instance.hpp"
|
||||
@@ -8,8 +8,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
|
||||
// g, k]
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_bilinear_instance.hpp"
|
||||
@@ -8,8 +8,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
|
||||
// g, k]
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_bilinear_instance.hpp"
|
||||
@@ -8,8 +8,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
|
||||
// g, k]
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_scale_instance.hpp"
|
||||
@@ -8,8 +8,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
|
||||
// g, k]
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_scale_instance.hpp"
|
||||
@@ -8,8 +8,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
|
||||
// g, k]
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_bwd_data/device_grouped_conv_bwd_data_xdl_scale_instance.hpp"
|
||||
@@ -8,8 +8,7 @@ namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
// Compilation parameters for out[n, di, hi, wi, g, c] * wei[g, k, z, y, x, c] = in[n, do, ho, wo,
|
||||
// g, k]
|
||||
|
||||
void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
|
||||
NDHWGK,
|
||||
|
||||
@@ -34,7 +34,8 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
int init_method,
|
||||
bool do_log,
|
||||
bool time_kernel,
|
||||
const ck::utils::conv::ConvParam& conv_param)
|
||||
const ck::utils::conv::ConvParam& conv_param,
|
||||
ck::index_t split_k = 1)
|
||||
{
|
||||
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
@@ -88,6 +89,7 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
// reset input to zero
|
||||
in_device_buf.SetZero();
|
||||
|
||||
float max_accumulated_value = 0;
|
||||
if(do_verification)
|
||||
{
|
||||
auto ref_conv = ck::tensor_operation::host::ReferenceConvBwdData<NDimSpatial,
|
||||
@@ -114,17 +116,19 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
in_element_op);
|
||||
|
||||
ref_invoker.Run(ref_argument);
|
||||
max_accumulated_value = *std::max_element(in_host.mData.begin(), in_host.mData.end());
|
||||
}
|
||||
|
||||
std::string best_op_name;
|
||||
float best_avg_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
float best_avg_time = 0;
|
||||
float best_tflops = 0;
|
||||
float best_gb_per_sec = 0;
|
||||
ck::index_t best_split_k = 1;
|
||||
|
||||
// profile device op instances
|
||||
bool pass = true;
|
||||
|
||||
auto run_impl = [&](auto& op_ptr, auto& argument_ptr) {
|
||||
auto run_impl = [&](auto& op_ptr, auto& argument_ptr, const index_t& split_k_for_run) {
|
||||
// workspace_sz will be equal to 0 for other layout than NGCHW
|
||||
const std::size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());
|
||||
DeviceMem workspace_dev(workspace_sz);
|
||||
@@ -150,7 +154,8 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
float gb_per_sec = num_btype / 1.E6 / avg_time;
|
||||
|
||||
std::cout << "Perf: " << std::setw(10) << avg_time << " ms, " << tflops << " TFlops, "
|
||||
<< gb_per_sec << " GB/s, " << op_name << std::endl;
|
||||
<< gb_per_sec << " GB/s, " << op_name << ", SplitK " << split_k_for_run
|
||||
<< std::endl;
|
||||
|
||||
if(tflops > best_tflops)
|
||||
{
|
||||
@@ -158,13 +163,39 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
best_tflops = tflops;
|
||||
best_avg_time = avg_time;
|
||||
best_gb_per_sec = gb_per_sec;
|
||||
best_split_k = split_k_for_run;
|
||||
}
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
in_device_buf.FromDevice(in_device.mData.data());
|
||||
|
||||
pass = pass & ck::utils::check_err(in_device, in_host);
|
||||
using ComputeType = std::conditional_t<sizeof(OutDataType) < sizeof(WeiDataType),
|
||||
OutDataType,
|
||||
WeiDataType>;
|
||||
using AccDataType =
|
||||
std::conditional_t<std::is_same_v<ComputeType, int8_t>, int32_t, float>;
|
||||
const index_t num_accums = conv_param.K_;
|
||||
// Calculate thresholds
|
||||
auto rtol = ck::utils::get_relative_threshold<ComputeType, InDataType, AccDataType>(
|
||||
num_accums / split_k_for_run);
|
||||
auto atol = ck::utils::get_absolute_threshold<ComputeType, InDataType, AccDataType>(
|
||||
max_accumulated_value / split_k_for_run, num_accums / split_k_for_run);
|
||||
// Calculate error due to split_k accumulation
|
||||
auto rtol_split_k =
|
||||
ck::utils::get_relative_threshold<InDataType, InDataType, InDataType>(
|
||||
split_k_for_run);
|
||||
auto atol_split_k =
|
||||
ck::utils::get_absolute_threshold<InDataType, InDataType, InDataType>(
|
||||
max_accumulated_value, split_k_for_run);
|
||||
// Use higher threshold
|
||||
rtol = std::max(rtol, rtol_split_k);
|
||||
atol = std::max(atol, atol_split_k);
|
||||
|
||||
pass = pass & ck::utils::check_err(
|
||||
in_device, in_host, "Error: Incorrect results!", rtol, atol);
|
||||
std::cout << "Relative error threshold: " << rtol
|
||||
<< " Absolute error threshold: " << atol << std::endl;
|
||||
|
||||
if(do_log)
|
||||
{
|
||||
@@ -225,35 +256,47 @@ bool profile_grouped_conv_bwd_data_impl(int do_verification,
|
||||
copy(conv_param.input_left_pads_, input_left_pads);
|
||||
copy(conv_param.input_right_pads_, input_right_pads);
|
||||
|
||||
std::vector<ck::index_t> split_k_list = {1, 2, 4, 8, 16, 32, 64, 128};
|
||||
|
||||
if(split_k > 0)
|
||||
{
|
||||
split_k_list = {split_k};
|
||||
}
|
||||
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
{
|
||||
auto argument_ptr =
|
||||
op_ptr->MakeArgumentPointer(static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
{},
|
||||
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
out_lengths,
|
||||
out_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
in_lengths,
|
||||
in_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
out_element_op,
|
||||
wei_element_op,
|
||||
in_element_op);
|
||||
for(std::size_t split_k_id = 0; split_k_id < split_k_list.size(); split_k_id++)
|
||||
{
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(
|
||||
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
|
||||
static_cast<WeiDataType*>(wei_device_buf.GetDeviceBuffer()),
|
||||
{},
|
||||
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
|
||||
out_lengths,
|
||||
out_strides,
|
||||
wei_lengths,
|
||||
wei_strides,
|
||||
{},
|
||||
{},
|
||||
in_lengths,
|
||||
in_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads,
|
||||
out_element_op,
|
||||
wei_element_op,
|
||||
in_element_op,
|
||||
split_k_list[split_k_id]);
|
||||
|
||||
run_impl(op_ptr, argument_ptr);
|
||||
run_impl(op_ptr, argument_ptr, split_k_list[split_k_id]);
|
||||
}
|
||||
}
|
||||
|
||||
std::cout << "Best configuration parameters:"
|
||||
<< "\nname: " << best_op_name << "\navg_time: " << best_avg_time
|
||||
<< "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << std::endl;
|
||||
<< "\ntflops: " << best_tflops << "\nGB/s: " << best_gb_per_sec << ", SplitK "
|
||||
<< best_split_k << std::endl;
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
@@ -68,8 +68,8 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
|
||||
const bool time_kernel = std::stoi(argv[7]);
|
||||
const int num_dim_spatial = std::stoi(argv[8]);
|
||||
|
||||
// 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial
|
||||
if(argc != 8 + 1 + 4 + 6 * num_dim_spatial)
|
||||
// 8 for control, 1 for num_dim_spatial, 4 for G/N/K/C, and 6 * num_dim_spatial, 1 for split-K
|
||||
if(argc != 8 + 1 + 4 + 6 * num_dim_spatial + 1)
|
||||
{
|
||||
print_helper_msg();
|
||||
return 1;
|
||||
@@ -77,6 +77,8 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
|
||||
|
||||
const auto params = ck::utils::conv::parse_conv_param(num_dim_spatial, 9, argv);
|
||||
|
||||
ck::index_t split_k = std::stoi(argv[8 + 1 + 4 + 6 * num_dim_spatial]);
|
||||
|
||||
using F32 = float;
|
||||
using F16 = ck::half_t;
|
||||
using BF16 = ck::bhalf_t;
|
||||
@@ -110,7 +112,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
|
||||
OutDataType,
|
||||
WeiDataType,
|
||||
InDataType>(
|
||||
do_verification, init_method, do_log, time_kernel, params);
|
||||
do_verification, init_method, do_log, time_kernel, params, split_k);
|
||||
|
||||
return pass ? 0 : 1;
|
||||
};
|
||||
|
||||
@@ -126,6 +126,8 @@ def run_ck_grouped_conv_bwd_data(args):
|
||||
args.ck_profier_op = "grouped_conv_bwd_data"
|
||||
parse_data_type(args)
|
||||
parse_layouts(args)
|
||||
# Test all split K value from the list {1, 2, 4, 8, 32, 64, 128}
|
||||
args.split_k_value = -1
|
||||
|
||||
cmd = [str(args.ck_profiler_cmd), str(args.ck_profier_op)]
|
||||
cmd += [str(args.data_type), str(args.layout)]
|
||||
@@ -136,6 +138,7 @@ def run_ck_grouped_conv_bwd_data(args):
|
||||
cmd += [str(args.in_channels)]
|
||||
add_conv_params_to_cmd(args, cmd)
|
||||
|
||||
cmd += [str(args.split_k_value)]
|
||||
run_ck_profiler_cmd(cmd)
|
||||
|
||||
|
||||
|
||||
@@ -2,6 +2,11 @@ add_gtest_executable(test_grouped_convnd_bwd_data_xdl test_grouped_convnd_bwd_da
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_grouped_convnd_bwd_data_xdl PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance)
|
||||
endif()
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_executable(test_grouped_convnd_bwd_data_xdl_large_cases test_grouped_convnd_bwd_data_xdl_large_cases.cpp)
|
||||
target_compile_options(test_grouped_convnd_bwd_data_xdl_large_cases PRIVATE -Wno-global-constructors -Wno-undef)
|
||||
target_link_libraries(test_grouped_convnd_bwd_data_xdl_large_cases PRIVATE gtest_main getopt::getopt utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance)
|
||||
endif()
|
||||
add_gtest_executable(test_grouped_convnd_bwd_data_wmma test_grouped_convnd_bwd_data_wmma.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_grouped_convnd_bwd_data_wmma PRIVATE utility device_grouped_conv2d_bwd_data_instance device_grouped_conv3d_bwd_data_instance)
|
||||
|
||||
@@ -21,26 +21,31 @@ class TestGroupedConvndBwdDataXdl : public ::testing::Test
|
||||
using InLayout = std::tuple_element_t<3, Tuple>;
|
||||
|
||||
std::vector<ck::utils::conv::ConvParam> conv_params;
|
||||
std::vector<ck::index_t> split_ks{1, 2};
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
void Run()
|
||||
{
|
||||
EXPECT_FALSE(conv_params.empty());
|
||||
bool pass = true;
|
||||
for(auto& param : conv_params)
|
||||
for(auto split_k : split_ks)
|
||||
{
|
||||
pass = pass && ck::profiler::profile_grouped_conv_bwd_data_impl<NDimSpatial,
|
||||
OutLayout,
|
||||
WeiLayout,
|
||||
InLayout,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType>(
|
||||
true, // do_verification
|
||||
1, // init_method: integer value
|
||||
false, // do_log
|
||||
false, // time_kernel
|
||||
param);
|
||||
for(auto& param : conv_params)
|
||||
{
|
||||
pass = pass && ck::profiler::profile_grouped_conv_bwd_data_impl<NDimSpatial,
|
||||
OutLayout,
|
||||
WeiLayout,
|
||||
InLayout,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType>(
|
||||
true, // do_verification
|
||||
1, // init_method: integer value
|
||||
false, // do_log
|
||||
false, // time_kernel
|
||||
param,
|
||||
split_k);
|
||||
}
|
||||
}
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
@@ -92,19 +97,16 @@ TYPED_TEST(TestGroupedConvndBwdDataXdl2d, Test2D)
|
||||
this->conv_params.clear();
|
||||
|
||||
this->conv_params.push_back(
|
||||
{2, 2, 4, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
{2, 2, 2, 192, 192, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->conv_params.push_back(
|
||||
{2, 2, 128, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
{2, 2, 2, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->conv_params.push_back(
|
||||
{2, 2, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
|
||||
{2, 2, 2, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{2, 2, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->conv_params.push_back({2, 1, 1, 1, 32, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->conv_params.push_back({2, 1, 1, 64, 3, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->conv_params.push_back({2, 1, 1, 1, 1, {8, 8}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
// SplitN case
|
||||
this->conv_params.push_back(
|
||||
{2, 1, 128, 4, 192, {2, 2}, {224, 224}, {224, 224}, {1, 1}, {0, 0}, {0, 0}});
|
||||
{2, 2, 2, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->conv_params.push_back({2, 1, 1, 1, 32, {8, 8}, {16, 16}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->conv_params.push_back({2, 1, 1, 64, 3, {8, 8}, {16, 16}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->conv_params.push_back({2, 1, 1, 1, 1, {8, 8}, {16, 16}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
|
||||
this->template Run<2>();
|
||||
}
|
||||
|
||||
@@ -112,28 +114,16 @@ TYPED_TEST(TestGroupedConvndBwdDataXdl3d, Test3D)
|
||||
{
|
||||
this->conv_params.clear();
|
||||
this->conv_params.push_back(
|
||||
{3, 2, 16, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
|
||||
{3, 2, 2, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{3, 2, 2, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
this->conv_params.push_back(
|
||||
{3, 2, 32, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
|
||||
{3, 2, 2, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
|
||||
this->conv_params.push_back(
|
||||
{3, 1, 1, 1, 32, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
{3, 1, 1, 1, 32, {3, 3, 3}, {4, 16, 16}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
this->conv_params.push_back(
|
||||
{3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
{3, 1, 1, 64, 3, {3, 3, 3}, {4, 16, 16}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
this->conv_params.push_back(
|
||||
{3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
// SplitN case
|
||||
this->conv_params.push_back({3,
|
||||
1,
|
||||
128,
|
||||
4,
|
||||
192,
|
||||
{2, 2, 2},
|
||||
{2, 224, 224},
|
||||
{1, 224, 224},
|
||||
{1, 1, 1},
|
||||
{0, 0, 0},
|
||||
{0, 0, 0}});
|
||||
{3, 1, 1, 1, 1, {3, 3, 3}, {4, 16, 16}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
|
||||
this->template Run<3>();
|
||||
}
|
||||
|
||||
@@ -0,0 +1,120 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <initializer_list>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "profiler/profile_grouped_conv_bwd_data_impl.hpp"
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdDataXdl : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using DataType = std::tuple_element_t<0, Tuple>;
|
||||
using OutLayout = std::tuple_element_t<1, Tuple>;
|
||||
using WeiLayout = std::tuple_element_t<2, Tuple>;
|
||||
using InLayout = std::tuple_element_t<3, Tuple>;
|
||||
|
||||
std::vector<ck::utils::conv::ConvParam> conv_params;
|
||||
std::vector<ck::index_t> split_ks{1, 2};
|
||||
|
||||
template <ck::index_t NDimSpatial>
|
||||
void Run()
|
||||
{
|
||||
EXPECT_FALSE(conv_params.empty());
|
||||
bool pass = true;
|
||||
for(auto split_k : split_ks)
|
||||
{
|
||||
for(auto& param : conv_params)
|
||||
{
|
||||
pass = pass && ck::profiler::profile_grouped_conv_bwd_data_impl<NDimSpatial,
|
||||
OutLayout,
|
||||
WeiLayout,
|
||||
InLayout,
|
||||
DataType,
|
||||
DataType,
|
||||
DataType>(
|
||||
true, // do_verification
|
||||
1, // init_method: integer value
|
||||
false, // do_log
|
||||
false, // time_kernel
|
||||
param,
|
||||
split_k);
|
||||
}
|
||||
}
|
||||
EXPECT_TRUE(pass);
|
||||
}
|
||||
};
|
||||
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using KernelTypes2d = ::testing::Types<std::tuple<float, GNHWK, GKYXC, GNHWC>,
|
||||
std::tuple<ck::half_t, GNHWK, GKYXC, GNHWC>,
|
||||
std::tuple<ck::bhalf_t, GNHWK, GKYXC, GNHWC>,
|
||||
std::tuple<float, NGKHW, GKYXC, NGCHW>,
|
||||
std::tuple<ck::half_t, NGKHW, GKYXC, NGCHW>,
|
||||
std::tuple<ck::bhalf_t, NGKHW, GKYXC, NGCHW>,
|
||||
std::tuple<float, NGKHW, GKCYX, NGCHW>,
|
||||
std::tuple<ck::half_t, NGKHW, GKCYX, NGCHW>,
|
||||
std::tuple<ck::bhalf_t, NGKHW, GKCYX, NGCHW>,
|
||||
std::tuple<float, NHWGK, GKYXC, NHWGC>,
|
||||
std::tuple<ck::half_t, NHWGK, GKYXC, NHWGC>,
|
||||
std::tuple<ck::bhalf_t, NHWGK, GKYXC, NHWGC>>;
|
||||
|
||||
using KernelTypes3d = ::testing::Types<std::tuple<float, GNDHWK, GKZYXC, GNDHWC>,
|
||||
std::tuple<ck::half_t, GNDHWK, GKZYXC, GNDHWC>,
|
||||
std::tuple<ck::bhalf_t, GNDHWK, GKZYXC, GNDHWC>,
|
||||
std::tuple<float, NGKDHW, GKZYXC, NGCDHW>,
|
||||
std::tuple<ck::half_t, NGKDHW, GKZYXC, NGCDHW>,
|
||||
std::tuple<ck::bhalf_t, NGKDHW, GKZYXC, NGCDHW>,
|
||||
std::tuple<float, NGKDHW, GKCZYX, NGCDHW>,
|
||||
std::tuple<ck::half_t, NGKDHW, GKCZYX, NGCDHW>,
|
||||
std::tuple<ck::bhalf_t, NGKDHW, GKCZYX, NGCDHW>,
|
||||
std::tuple<float, NDHWGK, GKZYXC, NDHWGC>,
|
||||
std::tuple<ck::half_t, NDHWGK, GKZYXC, NDHWGC>,
|
||||
std::tuple<ck::bhalf_t, NDHWGK, GKZYXC, NDHWGC>>;
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdDataXdl2d : public TestGroupedConvndBwdDataXdl<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
template <typename Tuple>
|
||||
class TestGroupedConvndBwdDataXdl3d : public TestGroupedConvndBwdDataXdl<Tuple>
|
||||
{
|
||||
};
|
||||
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdDataXdl2d, KernelTypes2d);
|
||||
TYPED_TEST_SUITE(TestGroupedConvndBwdDataXdl3d, KernelTypes3d);
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdDataXdl2d, Test2D)
|
||||
{
|
||||
this->conv_params.clear();
|
||||
// SplitN case
|
||||
this->conv_params.push_back(
|
||||
{2, 1, 128, 4, 192, {2, 2}, {224, 224}, {224, 224}, {1, 1}, {0, 0}, {0, 0}});
|
||||
this->template Run<2>();
|
||||
}
|
||||
|
||||
TYPED_TEST(TestGroupedConvndBwdDataXdl3d, Test3D)
|
||||
{
|
||||
this->conv_params.clear();
|
||||
// SplitN case
|
||||
this->conv_params.push_back({3,
|
||||
1,
|
||||
128,
|
||||
4,
|
||||
192,
|
||||
{2, 2, 2},
|
||||
{2, 224, 224},
|
||||
{1, 224, 224},
|
||||
{1, 1, 1},
|
||||
{0, 0, 0},
|
||||
{0, 0, 0}});
|
||||
this->template Run<3>();
|
||||
}
|
||||
Reference in New Issue
Block a user