Unify conv elementwise ops and layout definitions for fwd and bwd directions.

This commit is contained in:
Ville Pietilä
2026-01-13 02:43:37 -05:00
parent 7e02790293
commit 3e8f3907f2
19 changed files with 120 additions and 143 deletions

View File

@@ -25,7 +25,7 @@ struct ConvBwdWeightDlFactory
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::BwdWeightConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ElementwiseOps<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto BWD_CONV_SPECIALIZATION =

View File

@@ -28,7 +28,7 @@ struct ConvBwdWeightMultiDWmmaV3Factory
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::BwdWeightConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ElementwiseOps<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto BWD_CONV_SPECIALIZATION =

View File

@@ -28,7 +28,7 @@ struct ConvBwdWeightMultiDXdlFactory
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::BwdWeightConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ElementwiseOps<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto BWD_CONV_SPECIALIZATION =

View File

@@ -28,7 +28,7 @@ struct ConvBwdWeightTwoStageWmmaV3Factory
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::BwdWeightConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ElementwiseOps<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto BWD_CONV_SPECIALIZATION =

View File

@@ -28,7 +28,7 @@ struct ConvBwdWeightTwoStageXdlFactory
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::BwdWeightConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ElementwiseOps<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto BWD_CONV_SPECIALIZATION =

View File

@@ -28,7 +28,7 @@ struct ConvBwdWeightWmmaFactory
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::BwdWeightConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ElementwiseOps<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto BWD_CONV_SPECIALIZATION =

View File

@@ -28,7 +28,7 @@ struct ConvBwdWeightWmmaV3Factory
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::BwdWeightConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ElementwiseOps<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto BWD_CONV_SPECIALIZATION =

View File

@@ -28,7 +28,7 @@ struct ConvBwdWeightXdlFactory
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::BwdWeightConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ElementwiseOps<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto BWD_CONV_SPECIALIZATION =

View File

@@ -28,7 +28,7 @@ struct ConvBwdWeightXdlV3Factory
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::BwdWeightConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ElementwiseOps<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto BWD_CONV_SPECIALIZATION =

View File

@@ -26,7 +26,7 @@ struct ConvFwdDlFactory
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ElementwiseOps<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
@@ -94,13 +94,13 @@ struct ConvFwdDlFactory
typename Types::DsDataTypes,
typename Types::EDataType,
typename Types::AccDataType,
typename Layouts::ALayout,
typename Layouts::BLayout,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::DsLayout,
typename Layouts::ELayout,
typename Ops::AElementwiseOp,
typename Ops::BElementwiseOp,
typename Ops::CDEElementwiseOp,
typename Layouts::OutLayout,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
FWD_CONV_SPECIALIZATION,
GEMM_SPECIALIZATION,
BLOCK.block_size,

View File

@@ -28,7 +28,7 @@ struct ConvFwdLargeTensorFactory
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ElementwiseOps<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
@@ -59,19 +59,19 @@ struct ConvFwdLargeTensorFactory
using Instance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
SPATIAL_DIM,
typename Layouts::ALayout,
typename Layouts::BLayout,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::DsLayout,
typename Layouts::ELayout,
typename Layouts::OutLayout,
typename Types::ADataType,
typename Types::BDataType,
typename Types::AccDataType,
typename Types::CShuffleDataType,
typename Types::DsDataTypes,
typename Types::EDataType,
typename Ops::AElementwiseOp,
typename Ops::BElementwiseOp,
typename Ops::CDEElementwiseOp,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
SPECIALIZATION.conv_spec,
SPECIALIZATION.gemm_spec,
ALGORITHM.num_gemm_k_prefetch_stages,

View File

@@ -28,7 +28,7 @@ struct ConvFwdXdlV3Factory
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ElementwiseOps<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static_assert(ALGORITHM.transfer.a.lds_transfer.is_direct_load ==
@@ -64,19 +64,19 @@ struct ConvFwdXdlV3Factory
// The forward convolution kernel class instance.
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
SPATIAL_DIM,
typename Layouts::ALayout,
typename Layouts::BLayout,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::DsLayout,
typename Layouts::ELayout,
typename Layouts::OutLayout,
typename Types::ADataType,
typename Types::BDataType,
typename Types::AccDataType,
typename Types::CShuffleDataType,
typename Types::DsDataTypes,
typename Types::EDataType,
typename Ops::AElementwiseOp,
typename Ops::BElementwiseOp,
typename Ops::CDEElementwiseOp,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
SPECIALIZATION.conv_spec,
SPECIALIZATION.gemm_spec,
BLOCK.block_size,

View File

@@ -28,7 +28,7 @@ struct ConvFwdWmmaFactory
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ElementwiseOps<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
@@ -60,19 +60,19 @@ struct ConvFwdWmmaFactory
// The forward convolution kernel class instance.
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<
SPATIAL_DIM,
typename Layouts::ALayout,
typename Layouts::BLayout,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::DsLayout,
typename Layouts::ELayout,
typename Layouts::OutLayout,
typename Types::ADataType,
typename Types::BDataType,
typename Types::AccDataType,
typename Types::CShuffleDataType,
typename Types::DsDataTypes,
typename Types::EDataType,
typename Ops::AElementwiseOp,
typename Ops::BElementwiseOp,
typename Ops::CDEElementwiseOp,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
SPECIALIZATION.conv_spec,
SPECIALIZATION.gemm_spec,
ALGORITHM.num_gemm_k_prefetch_stages,

View File

@@ -28,7 +28,7 @@ struct ConvFwdXdlFactory
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ElementwiseOps<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
@@ -59,19 +59,19 @@ struct ConvFwdXdlFactory
// The forward convolution kernel class instance.
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
SPATIAL_DIM,
typename Layouts::ALayout,
typename Layouts::BLayout,
typename Layouts::InLayout,
typename Layouts::WeiLayout,
typename Layouts::DsLayout,
typename Layouts::ELayout,
typename Layouts::OutLayout,
typename Types::ADataType,
typename Types::BDataType,
typename Types::AccDataType,
typename Types::CShuffleDataType,
typename Types::DsDataTypes,
typename Types::EDataType,
typename Ops::AElementwiseOp,
typename Ops::BElementwiseOp,
typename Ops::CDEElementwiseOp,
typename Ops::InElementwiseOp,
typename Ops::WeiElementwiseOp,
typename Ops::OutElementwiseOp,
SPECIALIZATION.conv_spec,
SPECIALIZATION.gemm_spec,
ALGORITHM.num_gemm_k_prefetch_stages,

View File

@@ -62,30 +62,20 @@ consteval auto GetElementwiseOp()
}
template <auto Sig>
struct ElementwiseOps
struct ConvElementwiseOps
{
private:
static constexpr auto input_op = GetElementwiseOp<Sig.input>();
static constexpr auto weight_op = GetElementwiseOp<Sig.weight>();
static constexpr auto output_op = GetElementwiseOp<Sig.output>();
static constexpr bool is_forward = ConvDirectionIsForward<Sig>;
static constexpr bool is_bwd_weight = ConvDirectionIsBackwardWeight<Sig>;
using InputOp = typename decltype(input_op)::Op;
using WeightOp = typename decltype(weight_op)::Op;
using OutputOp = typename decltype(output_op)::Op;
public:
// Forward convolution elementwise ops
using AElementwiseOp = std::conditional_t<is_forward, InputOp, void>;
using BElementwiseOp = std::conditional_t<is_forward, WeightOp, void>;
using CDEElementwiseOp = std::conditional_t<is_forward, OutputOp, void>;
// Backward weight convolution elementwise ops
using InElementwiseOp = std::conditional_t<is_bwd_weight, InputOp, void>;
using WeiElementwiseOp = std::conditional_t<is_bwd_weight, WeightOp, void>;
using OutElementwiseOp = std::conditional_t<is_bwd_weight, OutputOp, void>;
using InElementwiseOp = typename decltype(input_op)::Op;
using WeiElementwiseOp = typename decltype(weight_op)::Op;
using OutElementwiseOp = typename decltype(output_op)::Op;
// TODO: Remove, now left for compatibility. Factories do not need it anymore.
// using AElementwiseOp = InElementwiseOp;
// using BElementwiseOp = WeiElementwiseOp;
// using CDEElementwiseOp = OutElementwiseOp;
};
} // namespace ck_tile::builder::factory::internal

View File

@@ -222,28 +222,15 @@ template <auto Signature, size_t SPATIAL_DIM>
ValidConvOutputLayoutForSpatialDim<Signature.output.config.layout, SPATIAL_DIM>)
struct ConvTensorLayouts
{
private:
static constexpr bool is_forward = ConvDirectionIsForward<Signature>;
static constexpr bool is_bwd_weight = ConvDirectionIsBackwardWeight<Signature>;
using InLayout = decltype(TensorLayoutToCK<Signature.input.config.layout>());
using WeiLayout = decltype(TensorLayoutToCK<Signature.weight.config.layout>());
using OutLayout = decltype(TensorLayoutToCK<Signature.output.config.layout>());
using DsLayout = decltype(GetAuxiliaryTensorLayouts<Signature, SPATIAL_DIM>())::type;
using InputLayout = decltype(TensorLayoutToCK<Signature.input.config.layout>());
using WeightLayout = decltype(TensorLayoutToCK<Signature.weight.config.layout>());
using OutputLayout = decltype(TensorLayoutToCK<Signature.output.config.layout>());
using AuxLayout = decltype(GetAuxiliaryTensorLayouts<Signature, SPATIAL_DIM>())::type;
public:
// Forward convolution layouts
using ALayout = std::conditional_t<is_forward, InputLayout, void>;
using BLayout = std::conditional_t<is_forward, WeightLayout, void>;
using ELayout = std::conditional_t<is_forward, OutputLayout, void>;
// Backward weight convolution layouts
using InLayout = std::conditional_t<is_bwd_weight, InputLayout, void>;
using WeiLayout = std::conditional_t<is_bwd_weight, WeightLayout, void>;
using OutLayout = std::conditional_t<is_bwd_weight, OutputLayout, void>;
// Applicable for all directions
using DsLayout = AuxLayout;
// TODO: Remove,now left for compatibility. Factories do not need it anymore.
// using ALayout = InLayout;
// using BLayout = WeiLayout;
// using ELayout = OutLayout;
};
} // namespace ck_tile::builder::factory::internal

View File

@@ -71,7 +71,7 @@ struct Args<SIGNATURE>
using OutputDescriptor = TensorDescriptor<OUTPUT_TYPE, OUTPUT_RANK>;
// TODO: We shouldn't need to call into an internal namespace here.
using Ops = factory::internal::ElementwiseOps<SIGNATURE>;
using Ops = factory::internal::ConvElementwiseOps<SIGNATURE>;
// TODO: We shouldn't need to call into an internal namespace here.
using Layouts = factory::internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
@@ -88,9 +88,9 @@ struct Args<SIGNATURE>
FilterExtent<SPATIAL_DIM> input_left_pad;
FilterExtent<SPATIAL_DIM> input_right_pad;
Ops::AElementwiseOp a_elementwise_op;
Ops::BElementwiseOp b_elementwise_op;
Ops::CDEElementwiseOp cde_elementwise_op;
Ops::InElementwiseOp a_elementwise_op;
Ops::WeiElementwiseOp b_elementwise_op;
Ops::OutElementwiseOp cde_elementwise_op;
/// This function returns the `TensorDescriptor` corresponding to
/// the input-tensor of the convolution problem. This can then
@@ -105,7 +105,7 @@ struct Args<SIGNATURE>
// function.
const auto param = to_ck_conv_param();
const auto desc = ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<
typename Layouts::ALayout>(param);
typename Layouts::InLayout>(param);
using Extent = typename InputDescriptor::Extent;
return InputDescriptor(Extent::from_vector(desc.GetLengths()),
Extent::from_vector(desc.GetStrides()));
@@ -119,7 +119,7 @@ struct Args<SIGNATURE>
// See note in implementation of `make_input_descriptor`.
const auto param = to_ck_conv_param();
const auto desc = ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<
typename Layouts::BLayout>(param);
typename Layouts::WeiLayout>(param);
using Extent = typename WeightDescriptor::Extent;
return WeightDescriptor(Extent::from_vector(desc.GetLengths()),
Extent::from_vector(desc.GetStrides()));
@@ -133,7 +133,7 @@ struct Args<SIGNATURE>
// See note in implementation of `make_input_descriptor`.
const auto param = to_ck_conv_param();
const auto desc = ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<
typename Layouts::ELayout>(param);
typename Layouts::OutLayout>(param);
using Extent = typename OutputDescriptor::Extent;
return OutputDescriptor(Extent::from_vector(desc.GetLengths()),
Extent::from_vector(desc.GetStrides()));

View File

@@ -27,7 +27,7 @@ template <typename Conv,
auto SIGNATURE,
size_t SPATIAL_DIM = SIGNATURE.spatial_dim,
// TODO: We shouldn't need to call into an internal namespace here.
typename Ops = factory::internal::ElementwiseOps<SIGNATURE>>
typename Ops = factory::internal::ConvElementwiseOps<SIGNATURE>>
concept CkConvInstance = requires(Conv& conv,
// TODO: This should be changed depending on IsMultiA etc.
// Currently that is not yet supported elsewhere anyway.
@@ -37,9 +37,9 @@ concept CkConvInstance = requires(Conv& conv,
std::array<index_t, SPATIAL_DIM + 3> lengths,
std::array<index_t, SPATIAL_DIM + 3> strides,
std::array<index_t, SPATIAL_DIM> filter,
Ops::AElementwiseOp elementwise_a,
Ops::BElementwiseOp elementwise_b,
Ops::CDEElementwiseOp elementwise_cde) {
Ops::InElementwiseOp elementwise_a,
Ops::WeiElementwiseOp elementwise_b,
Ops::OutElementwiseOp elementwise_cde) {
{
conv.MakeArgument(p_a,
p_b,

View File

@@ -40,9 +40,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NWGC_GKXC_NWGK)
using TensorLayouts = ConvTensorLayouts<sig, 1>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NWGK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NWGK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -59,9 +59,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKXC_NGKW)
using TensorLayouts = ConvTensorLayouts<sig, 1>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -78,9 +78,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_GNWC_GKXC_GNWK)
using TensorLayouts = ConvTensorLayouts<sig, 1>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNWC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNWK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::GNWC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::GNWK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -97,9 +97,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor1D_NGCW_GKCX_NGKW)
using TensorLayouts = ConvTensorLayouts<sig, 1>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKCX>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKCX>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -116,9 +116,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKYXC_NGKHW)
using TensorLayouts = ConvTensorLayouts<sig, 2>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -135,9 +135,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NHWGC_GKYXC_NHWGK)
using TensorLayouts = ConvTensorLayouts<sig, 2>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NHWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NHWGK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NHWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NHWGK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -154,9 +154,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_GNHWC_GKYXC_GNHWK)
using TensorLayouts = ConvTensorLayouts<sig, 2>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNHWC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNHWK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::GNHWC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::GNHWK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -173,9 +173,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor2D_NGCHW_GKCYX_NGKHW)
using TensorLayouts = ConvTensorLayouts<sig, 2>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKCYX>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKCYX>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -192,9 +192,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_NGCDHW_GKCZYX_NGKDHW)
using TensorLayouts = ConvTensorLayouts<sig, 3>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCDHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKCZYX>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKDHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCDHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKCZYX>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKDHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -211,9 +211,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_NDHWGC_GKZYXC_NDHWGK)
using TensorLayouts = ConvTensorLayouts<sig, 3>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NDHWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKZYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NDHWGK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NDHWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKZYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NDHWGK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -230,9 +230,9 @@ TEST(ConvTensorLayout, AssignsLayoutsFor3D_GNDHWC_GKZYXC_GNDHWK)
using TensorLayouts = ConvTensorLayouts<sig, 3>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNDHWC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKZYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNDHWK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::GNDHWC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKZYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::GNDHWK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ck::Tuple<>>));
}
@@ -389,9 +389,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasG_K)
using TensorLayouts = ConvTensorLayouts<sig, 2>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NGCHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NGKHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NGCHW>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NGKHW>));
using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::G_K>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
@@ -416,9 +416,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithSingleBiasGC)
using TensorLayouts = ConvTensorLayouts<sig, 2>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NHWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NHWGK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NHWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NHWGK>));
using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::GC>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
@@ -444,9 +444,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv2DWithTwoAuxiliaryTensors)
using TensorLayouts = ConvTensorLayouts<sig, 2>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::GNHWC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::GNHWK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::GNHWC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::GNHWK>));
using ExpectedDsLayout =
ck::Tuple<ck::tensor_layout::convolution::G_K, ck::tensor_layout::convolution::GC>;
@@ -472,9 +472,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv1DWithBias)
using TensorLayouts = ConvTensorLayouts<sig, 1>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NWGK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NWGK>));
using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::G_K>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));
@@ -499,9 +499,9 @@ TEST(ConvTensorLayoutsWithAuxiliary, Conv3DWithBias)
using TensorLayouts = ConvTensorLayouts<sig, 3>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::ALayout, ck::tensor_layout::convolution::NDHWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::BLayout, ck::tensor_layout::convolution::GKZYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::ELayout, ck::tensor_layout::convolution::NDHWGK>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::InLayout, ck::tensor_layout::convolution::NDHWGC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::WeiLayout, ck::tensor_layout::convolution::GKZYXC>));
EXPECT_TRUE((std::is_same_v<TensorLayouts::OutLayout, ck::tensor_layout::convolution::NDHWGK>));
using ExpectedDsLayout = ck::Tuple<ck::tensor_layout::convolution::G_C>;
EXPECT_TRUE((std::is_same_v<TensorLayouts::DsLayout, ExpectedDsLayout>));