mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
Unify conv elementwise ops and layout definitions for fwd and bwd directions.
This commit is contained in:
@@ -25,7 +25,7 @@ struct ConvBwdWeightDlFactory
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::BwdWeightConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
|
||||
@@ -28,7 +28,7 @@ struct ConvBwdWeightMultiDWmmaV3Factory
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::BwdWeightConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
|
||||
@@ -28,7 +28,7 @@ struct ConvBwdWeightMultiDXdlFactory
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::BwdWeightConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
|
||||
@@ -28,7 +28,7 @@ struct ConvBwdWeightTwoStageWmmaV3Factory
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::BwdWeightConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
|
||||
@@ -28,7 +28,7 @@ struct ConvBwdWeightTwoStageXdlFactory
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::BwdWeightConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
|
||||
@@ -28,7 +28,7 @@ struct ConvBwdWeightWmmaFactory
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::BwdWeightConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
|
||||
@@ -28,7 +28,7 @@ struct ConvBwdWeightWmmaV3Factory
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::BwdWeightConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
|
||||
@@ -28,7 +28,7 @@ struct ConvBwdWeightXdlFactory
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::BwdWeightConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
|
||||
@@ -28,7 +28,7 @@ struct ConvBwdWeightXdlV3Factory
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::BwdWeightConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BWD_CONV_SPECIALIZATION =
|
||||
|
||||
@@ -26,7 +26,7 @@ struct ConvFwdDlFactory
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
@@ -94,13 +94,13 @@ struct ConvFwdDlFactory
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
typename Layouts::OutLayout,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
FWD_CONV_SPECIALIZATION,
|
||||
GEMM_SPECIALIZATION,
|
||||
BLOCK.block_size,
|
||||
|
||||
@@ -28,7 +28,7 @@ struct ConvFwdLargeTensorFactory
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
@@ -59,19 +59,19 @@ struct ConvFwdLargeTensorFactory
|
||||
using Instance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Layouts::OutLayout,
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::CShuffleDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataType,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
SPECIALIZATION.conv_spec,
|
||||
SPECIALIZATION.gemm_spec,
|
||||
ALGORITHM.num_gemm_k_prefetch_stages,
|
||||
|
||||
@@ -28,7 +28,7 @@ struct ConvFwdXdlV3Factory
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static_assert(ALGORITHM.transfer.a.lds_transfer.is_direct_load ==
|
||||
@@ -64,19 +64,19 @@ struct ConvFwdXdlV3Factory
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Layouts::OutLayout,
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::CShuffleDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataType,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
SPECIALIZATION.conv_spec,
|
||||
SPECIALIZATION.gemm_spec,
|
||||
BLOCK.block_size,
|
||||
|
||||
@@ -28,7 +28,7 @@ struct ConvFwdWmmaFactory
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
@@ -60,19 +60,19 @@ struct ConvFwdWmmaFactory
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Layouts::OutLayout,
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::CShuffleDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataType,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
SPECIALIZATION.conv_spec,
|
||||
SPECIALIZATION.gemm_spec,
|
||||
ALGORITHM.num_gemm_k_prefetch_stages,
|
||||
|
||||
@@ -28,7 +28,7 @@ struct ConvFwdXdlFactory
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
@@ -59,19 +59,19 @@ struct ConvFwdXdlFactory
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::InLayout,
|
||||
typename Layouts::WeiLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Layouts::OutLayout,
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::CShuffleDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataType,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
typename Ops::InElementwiseOp,
|
||||
typename Ops::WeiElementwiseOp,
|
||||
typename Ops::OutElementwiseOp,
|
||||
SPECIALIZATION.conv_spec,
|
||||
SPECIALIZATION.gemm_spec,
|
||||
ALGORITHM.num_gemm_k_prefetch_stages,
|
||||
|
||||
@@ -62,30 +62,20 @@ consteval auto GetElementwiseOp()
|
||||
}
|
||||
|
||||
template <auto Sig>
|
||||
struct ElementwiseOps
|
||||
struct ConvElementwiseOps
|
||||
{
|
||||
private:
|
||||
static constexpr auto input_op = GetElementwiseOp<Sig.input>();
|
||||
static constexpr auto weight_op = GetElementwiseOp<Sig.weight>();
|
||||
static constexpr auto output_op = GetElementwiseOp<Sig.output>();
|
||||
|
||||
static constexpr bool is_forward = ConvDirectionIsForward<Sig>;
|
||||
static constexpr bool is_bwd_weight = ConvDirectionIsBackwardWeight<Sig>;
|
||||
|
||||
using InputOp = typename decltype(input_op)::Op;
|
||||
using WeightOp = typename decltype(weight_op)::Op;
|
||||
using OutputOp = typename decltype(output_op)::Op;
|
||||
|
||||
public:
|
||||
// Forward convolution elementwise ops
|
||||
using AElementwiseOp = std::conditional_t<is_forward, InputOp, void>;
|
||||
using BElementwiseOp = std::conditional_t<is_forward, WeightOp, void>;
|
||||
using CDEElementwiseOp = std::conditional_t<is_forward, OutputOp, void>;
|
||||
|
||||
// Backward weight convolution elementwise ops
|
||||
using InElementwiseOp = std::conditional_t<is_bwd_weight, InputOp, void>;
|
||||
using WeiElementwiseOp = std::conditional_t<is_bwd_weight, WeightOp, void>;
|
||||
using OutElementwiseOp = std::conditional_t<is_bwd_weight, OutputOp, void>;
|
||||
using InElementwiseOp = typename decltype(input_op)::Op;
|
||||
using WeiElementwiseOp = typename decltype(weight_op)::Op;
|
||||
using OutElementwiseOp = typename decltype(output_op)::Op;
|
||||
|
||||
// TODO: Remove, now left for compatibility. Factories do not need it anymore.
|
||||
// using AElementwiseOp = InElementwiseOp;
|
||||
// using BElementwiseOp = WeiElementwiseOp;
|
||||
// using CDEElementwiseOp = OutElementwiseOp;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
|
||||
@@ -222,28 +222,15 @@ template <auto Signature, size_t SPATIAL_DIM>
|
||||
ValidConvOutputLayoutForSpatialDim<Signature.output.config.layout, SPATIAL_DIM>)
|
||||
struct ConvTensorLayouts
|
||||
{
|
||||
private:
|
||||
static constexpr bool is_forward = ConvDirectionIsForward<Signature>;
|
||||
static constexpr bool is_bwd_weight = ConvDirectionIsBackwardWeight<Signature>;
|
||||
using InLayout = decltype(TensorLayoutToCK<Signature.input.config.layout>());
|
||||
using WeiLayout = decltype(TensorLayoutToCK<Signature.weight.config.layout>());
|
||||
using OutLayout = decltype(TensorLayoutToCK<Signature.output.config.layout>());
|
||||
using DsLayout = decltype(GetAuxiliaryTensorLayouts<Signature, SPATIAL_DIM>())::type;
|
||||
|
||||
using InputLayout = decltype(TensorLayoutToCK<Signature.input.config.layout>());
|
||||
using WeightLayout = decltype(TensorLayoutToCK<Signature.weight.config.layout>());
|
||||
using OutputLayout = decltype(TensorLayoutToCK<Signature.output.config.layout>());
|
||||
using AuxLayout = decltype(GetAuxiliaryTensorLayouts<Signature, SPATIAL_DIM>())::type;
|
||||
|
||||
public:
|
||||
// Forward convolution layouts
|
||||
using ALayout = std::conditional_t<is_forward, InputLayout, void>;
|
||||
using BLayout = std::conditional_t<is_forward, WeightLayout, void>;
|
||||
using ELayout = std::conditional_t<is_forward, OutputLayout, void>;
|
||||
|
||||
// Backward weight convolution layouts
|
||||
using InLayout = std::conditional_t<is_bwd_weight, InputLayout, void>;
|
||||
using WeiLayout = std::conditional_t<is_bwd_weight, WeightLayout, void>;
|
||||
using OutLayout = std::conditional_t<is_bwd_weight, OutputLayout, void>;
|
||||
|
||||
// Applicable for all directions
|
||||
using DsLayout = AuxLayout;
|
||||
// TODO: Remove,now left for compatibility. Factories do not need it anymore.
|
||||
// using ALayout = InLayout;
|
||||
// using BLayout = WeiLayout;
|
||||
// using ELayout = OutLayout;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
|
||||
@@ -71,7 +71,7 @@ struct Args<SIGNATURE>
|
||||
using OutputDescriptor = TensorDescriptor<OUTPUT_TYPE, OUTPUT_RANK>;
|
||||
|
||||
// TODO: We shouldn't need to call into an internal namespace here.
|
||||
using Ops = factory::internal::ElementwiseOps<SIGNATURE>;
|
||||
using Ops = factory::internal::ConvElementwiseOps<SIGNATURE>;
|
||||
|
||||
// TODO: We shouldn't need to call into an internal namespace here.
|
||||
using Layouts = factory::internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
@@ -88,9 +88,9 @@ struct Args<SIGNATURE>
|
||||
FilterExtent<SPATIAL_DIM> input_left_pad;
|
||||
FilterExtent<SPATIAL_DIM> input_right_pad;
|
||||
|
||||
Ops::AElementwiseOp a_elementwise_op;
|
||||
Ops::BElementwiseOp b_elementwise_op;
|
||||
Ops::CDEElementwiseOp cde_elementwise_op;
|
||||
Ops::InElementwiseOp a_elementwise_op;
|
||||
Ops::WeiElementwiseOp b_elementwise_op;
|
||||
Ops::OutElementwiseOp cde_elementwise_op;
|
||||
|
||||
/// This function returns the `TensorDescriptor` corresponding to
|
||||
/// the input-tensor of the convolution problem. This can then
|
||||
@@ -105,7 +105,7 @@ struct Args<SIGNATURE>
|
||||
// function.
|
||||
const auto param = to_ck_conv_param();
|
||||
const auto desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<
|
||||
typename Layouts::ALayout>(param);
|
||||
typename Layouts::InLayout>(param);
|
||||
using Extent = typename InputDescriptor::Extent;
|
||||
return InputDescriptor(Extent::from_vector(desc.GetLengths()),
|
||||
Extent::from_vector(desc.GetStrides()));
|
||||
@@ -119,7 +119,7 @@ struct Args<SIGNATURE>
|
||||
// See note in implementation of `make_input_descriptor`.
|
||||
const auto param = to_ck_conv_param();
|
||||
const auto desc = ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<
|
||||
typename Layouts::BLayout>(param);
|
||||
typename Layouts::WeiLayout>(param);
|
||||
using Extent = typename WeightDescriptor::Extent;
|
||||
return WeightDescriptor(Extent::from_vector(desc.GetLengths()),
|
||||
Extent::from_vector(desc.GetStrides()));
|
||||
@@ -133,7 +133,7 @@ struct Args<SIGNATURE>
|
||||
// See note in implementation of `make_input_descriptor`.
|
||||
const auto param = to_ck_conv_param();
|
||||
const auto desc = ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<
|
||||
typename Layouts::ELayout>(param);
|
||||
typename Layouts::OutLayout>(param);
|
||||
using Extent = typename OutputDescriptor::Extent;
|
||||
return OutputDescriptor(Extent::from_vector(desc.GetLengths()),
|
||||
Extent::from_vector(desc.GetStrides()));
|
||||
|
||||
@@ -27,7 +27,7 @@ template <typename Conv,
|
||||
auto SIGNATURE,
|
||||
size_t SPATIAL_DIM = SIGNATURE.spatial_dim,
|
||||
// TODO: We shouldn't need to call into an internal namespace here.
|
||||
typename Ops = factory::internal::ElementwiseOps<SIGNATURE>>
|
||||
typename Ops = factory::internal::ConvElementwiseOps<SIGNATURE>>
|
||||
concept CkConvInstance = requires(Conv& conv,
|
||||
// TODO: This should be changed depending on IsMultiA etc.
|
||||
// Currently that is not yet supported elsewhere anyway.
|
||||
@@ -37,9 +37,9 @@ concept CkConvInstance = requires(Conv& conv,
|
||||
std::array<index_t, SPATIAL_DIM + 3> lengths,
|
||||
std::array<index_t, SPATIAL_DIM + 3> strides,
|
||||
std::array<index_t, SPATIAL_DIM> filter,
|
||||
Ops::AElementwiseOp elementwise_a,
|
||||
Ops::BElementwiseOp elementwise_b,
|
||||
Ops::CDEElementwiseOp elementwise_cde) {
|
||||
Ops::InElementwiseOp elementwise_a,
|
||||
Ops::WeiElementwiseOp elementwise_b,
|
||||
Ops::OutElementwiseOp elementwise_cde) {
|
||||
{
|
||||
conv.MakeArgument(p_a,
|
||||
p_b,
|
||||
|
||||
@@ -40,9 +40,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NWGC_GKXC_NWGK)
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 1>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NWGK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NWGK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -59,9 +59,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKXC_NGKW)
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 1>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -78,9 +78,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_GNWC_GKXC_GNWK)
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 1>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNWC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNWK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::GNWC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::GNWK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -97,9 +97,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKCX_NGKW)
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 1>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKCX>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKCX>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -116,9 +116,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKYXC_NGKHW)
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -135,9 +135,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NHWGC_GKYXC_NHWGK)
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NHWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NHWGK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NHWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NHWGK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -154,9 +154,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_GNHWC_GKYXC_GNHWK)
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNHWC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNHWK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::GNHWC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::GNHWK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -173,9 +173,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKCYX_NGKHW)
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKCYX>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKCYX>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -192,9 +192,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_NGCDHW_GKCZYX_NGKDHW)
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 3>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCDHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKCZYX>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKDHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCDHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKCZYX>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKDHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -211,9 +211,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_NDHWGC_GKZYXC_NDHWGK)
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 3>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NDHWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKZYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NDHWGK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NDHWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKZYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NDHWGK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -230,9 +230,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_GNDHWC_GKZYXC_GNDHWK)
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 3>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNDHWC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKZYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNDHWK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::GNDHWC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKZYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::GNDHWK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
|
||||
}
|
||||
|
||||
@@ -389,9 +389,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasG_K)
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCHW>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKHW>));
|
||||
|
||||
using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::G_K>;
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
|
||||
@@ -416,9 +416,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasGC)
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NHWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NHWGK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NHWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NHWGK>));
|
||||
|
||||
using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::GC>;
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
|
||||
@@ -444,9 +444,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithTwoAuxiliaryTensors)
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 2>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNHWC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNHWK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::GNHWC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::GNHWK>));
|
||||
|
||||
using ExpectedDsLayout =
|
||||
ck::Tuple<ck::tensor_layout::convolution::G_K, ck::tensor_layout::convolution::GC>;
|
||||
@@ -472,9 +472,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv1DWithBias)
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 1>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NWGK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NWGK>));
|
||||
|
||||
using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::G_K>;
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
|
||||
@@ -499,9 +499,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv3DWithBias)
|
||||
|
||||
using TensorLayouts = ConvTensorLayouts<sig, 3>;
|
||||
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NDHWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKZYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NDHWGK>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NDHWGC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKZYXC>));
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NDHWGK>));
|
||||
|
||||
using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::G_C>;
|
||||
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
|
||||
|
||||
Reference in New Issue
Block a user