mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 04:07:56 +00:00
Rename layouts to channels first or channels last.
Also add a spatial dimension parameter to layout mapping in the factory to handle different layout enums for 2D and 3D convolutions
This commit is contained in:
@@ -12,11 +12,18 @@
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// Type mappings from the builder GroupConvLayout enum class to the CK tensor data types.
|
||||
template <GroupConvLayout Layout>
|
||||
struct ConvTensorLayouts;
|
||||
template <GroupConvLayout Layout, int SPATIAL_DIM>
|
||||
requires(ConvSpatialDim<SPATIAL_DIM>)
|
||||
struct ConvTensorLayouts
|
||||
{
|
||||
// This will trigger if a specialization for the given layout is not found.
|
||||
// We should always catch this in an earlier validation check.
|
||||
static_assert(sizeof(Layout) == 0,
|
||||
"Internal error. Unsupported layout for convolution factory.");
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout::NGCHW_GKCYX_NGKHW>
|
||||
struct ConvTensorLayouts<GroupConvLayout::CHANNELS_FIRST, 2>
|
||||
{
|
||||
// Channels first convolution layout.
|
||||
using ALayout = ck::tensor_layout::convolution::NHWGC;
|
||||
@@ -26,7 +33,7 @@ struct ConvTensorLayouts<GroupConvLayout::NGCHW_GKCYX_NGKHW>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout::NHWGC_GKYXC_NHWGK>
|
||||
struct ConvTensorLayouts<GroupConvLayout::CHANNELS_LAST, 2>
|
||||
{
|
||||
// Channels last convolution layout.
|
||||
using ALayout = ck::tensor_layout::convolution::NHWGC;
|
||||
@@ -36,7 +43,7 @@ struct ConvTensorLayouts<GroupConvLayout::NHWGC_GKYXC_NHWGK>
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout::NDHWGC_GKZYXC_NDHWGK>
|
||||
struct ConvTensorLayouts<GroupConvLayout::CHANNELS_LAST, 3>
|
||||
{
|
||||
// Channels last convolution layout.
|
||||
using ALayout = ck::tensor_layout::convolution::NDHWGC;
|
||||
@@ -45,7 +52,6 @@ struct ConvTensorLayouts<GroupConvLayout::NDHWGC_GKZYXC_NDHWGK>
|
||||
using ELayout = ck::tensor_layout::convolution::NDHWGK;
|
||||
};
|
||||
|
||||
|
||||
// Type mappings from builder convolution data type to CK tensor types.
|
||||
template <DataType T>
|
||||
struct ConvTensorTypes
|
||||
@@ -283,7 +289,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvFactory
|
||||
{
|
||||
static constexpr int SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = ConvTensorLayouts<SIGNATURE.layout>;
|
||||
using Layouts = ConvTensorLayouts<SIGNATURE.layout, SPATIAL_DIM>;
|
||||
using Types = ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = ConvPassThroughOps;
|
||||
static constexpr ConvSpec SPECIALIZATION{
|
||||
|
||||
@@ -10,9 +10,8 @@ namespace ck_tile::builder {
|
||||
// Layouts for grouped convolutions.
|
||||
enum class GroupConvLayout
|
||||
{
|
||||
NHWGC_GKYXC_NHWGK, // Channels-last
|
||||
NDHWGC_GKZYXC_NDHWGK, // Channels-last
|
||||
NGCHW_GKCYX_NGKHW // Channels-first
|
||||
CHANNELS_LAST, // Channels-last NHWGC_GKYXC_NHWGK
|
||||
CHANNELS_FIRST // Channels-first NGCHW_GKCYX_NGKHW
|
||||
};
|
||||
|
||||
// Spatial dimensionalities of grouped convolutions.
|
||||
|
||||
@@ -13,7 +13,7 @@ struct ConvSignature
|
||||
{
|
||||
int spatial_dim = 2;
|
||||
ckb::ConvDirection direction = ckb::ConvDirection::Forward;
|
||||
ckb::GroupConvLayout layout = ckb::GroupConvLayout::NHWGC_GKYXC_NHWGK;
|
||||
ckb::GroupConvLayout layout = ckb::GroupConvLayout::CHANNELS_LAST;
|
||||
ckb::DataType data_type = ckb::DataType::FP16;
|
||||
};
|
||||
static_assert(ckb::ConvSignatureDescriptor<ConvSignature>);
|
||||
|
||||
@@ -13,7 +13,7 @@ struct ConvSignature
|
||||
{
|
||||
int spatial_dim = 2;
|
||||
ckb::ConvDirection direction = ckb::ConvDirection::Forward;
|
||||
ckb::GroupConvLayout layout = ckb::GroupConvLayout::NHWGC_GKYXC_NHWGK;
|
||||
ckb::GroupConvLayout layout = ckb::GroupConvLayout::CHANNELS_LAST;
|
||||
ckb::DataType data_type = ckb::DataType::FP16;
|
||||
};
|
||||
static_assert(ckb::ConvSignatureDescriptor<ConvSignature>);
|
||||
|
||||
@@ -13,7 +13,7 @@ struct ConvSignature
|
||||
{
|
||||
int spatial_dim = 3;
|
||||
ckb::ConvDirection direction = ckb::ConvDirection::Forward;
|
||||
ckb::GroupConvLayout layout = ckb::GroupConvLayout::NDHWGC_GKZYXC_NDHWGK;
|
||||
ckb::GroupConvLayout layout = ckb::GroupConvLayout::CHANNELS_LAST;
|
||||
ckb::DataType data_type = ckb::DataType::BF16;
|
||||
};
|
||||
static_assert(ckb::ConvSignatureDescriptor<ConvSignature>);
|
||||
|
||||
@@ -19,7 +19,7 @@ struct ConvSignature
|
||||
{
|
||||
int spatial_dim = 2;
|
||||
ckb::ConvDirection direction = ckb::ConvDirection::Forward;
|
||||
ckb::GroupConvLayout layout = ckb::GroupConvLayout::NHWGC_GKYXC_NHWGK;
|
||||
ckb::GroupConvLayout layout = ckb::GroupConvLayout::CHANNELS_LAST;
|
||||
ckb::DataType data_type = ckb::DataType::FP16;
|
||||
};
|
||||
static_assert(ckb::ConvSignatureDescriptor<ConvSignature>);
|
||||
|
||||
Reference in New Issue
Block a user