Rename layouts to channels first or channels last.

Also add a spatial dimension parameter to layout mapping in the factory to handle different layout enums for 2D and 3D convolutions
This commit is contained in:
John Shumway
2025-09-14 14:03:35 +00:00
parent 63fc483c1e
commit fe6140353f
6 changed files with 19 additions and 14 deletions

View File

@@ -12,11 +12,18 @@
namespace ck_tile::builder {
// Type mappings from the builder GroupConvLayout enum class to the CK tensor data types.
template <GroupConvLayout Layout>
struct ConvTensorLayouts;
template <GroupConvLayout Layout, int SPATIAL_DIM>
requires(ConvSpatialDim<SPATIAL_DIM>)
struct ConvTensorLayouts
{
// This will trigger if a specialization for the given layout is not found.
// We should always catch this in an earlier validation check.
static_assert(sizeof(Layout) == 0,
"Internal error. Unsupported layout for convolution factory.");
};
template <>
struct ConvTensorLayouts<GroupConvLayout::NGCHW_GKCYX_NGKHW>
struct ConvTensorLayouts<GroupConvLayout::CHANNELS_FIRST, 2>
{
// Channels first convolution layout.
using ALayout = ck::tensor_layout::convolution::NHWGC;
@@ -26,7 +33,7 @@ struct ConvTensorLayouts<GroupConvLayout::NGCHW_GKCYX_NGKHW>
};
template <>
struct ConvTensorLayouts<GroupConvLayout::NHWGC_GKYXC_NHWGK>
struct ConvTensorLayouts<GroupConvLayout::CHANNELS_LAST, 2>
{
// Channels last convolution layout.
using ALayout = ck::tensor_layout::convolution::NHWGC;
@@ -36,7 +43,7 @@ struct ConvTensorLayouts<GroupConvLayout::NHWGC_GKYXC_NHWGK>
};
template <>
struct ConvTensorLayouts<GroupConvLayout::NDHWGC_GKZYXC_NDHWGK>
struct ConvTensorLayouts<GroupConvLayout::CHANNELS_LAST, 3>
{
// Channels last convolution layout.
using ALayout = ck::tensor_layout::convolution::NDHWGC;
@@ -45,7 +52,6 @@ struct ConvTensorLayouts<GroupConvLayout::NDHWGC_GKZYXC_NDHWGK>
using ELayout = ck::tensor_layout::convolution::NDHWGK;
};
// Type mappings from builder convolution data type to CK tensor types.
template <DataType T>
struct ConvTensorTypes
@@ -283,7 +289,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFactory
{
static constexpr int SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = ConvTensorLayouts<SIGNATURE.layout>;
using Layouts = ConvTensorLayouts<SIGNATURE.layout, SPATIAL_DIM>;
using Types = ConvTensorTypes<SIGNATURE.data_type>;
using Ops = ConvPassThroughOps;
static constexpr ConvSpec SPECIALIZATION{

View File

@@ -10,9 +10,8 @@ namespace ck_tile::builder {
// Layouts for grouped convolutions.
enum class GroupConvLayout
{
NHWGC_GKYXC_NHWGK, // Channels-last
NDHWGC_GKZYXC_NDHWGK, // Channels-last
NGCHW_GKCYX_NGKHW // Channels-first
CHANNELS_LAST, // Channels-last NHWGC_GKYXC_NHWGK
CHANNELS_FIRST // Channels-first NGCHW_GKCYX_NGKHW
};
// Spatial dimensionalities of grouped convolutions.

View File

@@ -13,7 +13,7 @@ struct ConvSignature
{
int spatial_dim = 2;
ckb::ConvDirection direction = ckb::ConvDirection::Forward;
ckb::GroupConvLayout layout = ckb::GroupConvLayout::NHWGC_GKYXC_NHWGK;
ckb::GroupConvLayout layout = ckb::GroupConvLayout::CHANNELS_LAST;
ckb::DataType data_type = ckb::DataType::FP16;
};
static_assert(ckb::ConvSignatureDescriptor<ConvSignature>);

View File

@@ -13,7 +13,7 @@ struct ConvSignature
{
int spatial_dim = 2;
ckb::ConvDirection direction = ckb::ConvDirection::Forward;
ckb::GroupConvLayout layout = ckb::GroupConvLayout::NHWGC_GKYXC_NHWGK;
ckb::GroupConvLayout layout = ckb::GroupConvLayout::CHANNELS_LAST;
ckb::DataType data_type = ckb::DataType::FP16;
};
static_assert(ckb::ConvSignatureDescriptor<ConvSignature>);

View File

@@ -13,7 +13,7 @@ struct ConvSignature
{
int spatial_dim = 3;
ckb::ConvDirection direction = ckb::ConvDirection::Forward;
ckb::GroupConvLayout layout = ckb::GroupConvLayout::NDHWGC_GKZYXC_NDHWGK;
ckb::GroupConvLayout layout = ckb::GroupConvLayout::CHANNELS_LAST;
ckb::DataType data_type = ckb::DataType::BF16;
};
static_assert(ckb::ConvSignatureDescriptor<ConvSignature>);

View File

@@ -19,7 +19,7 @@ struct ConvSignature
{
int spatial_dim = 2;
ckb::ConvDirection direction = ckb::ConvDirection::Forward;
ckb::GroupConvLayout layout = ckb::GroupConvLayout::NHWGC_GKYXC_NHWGK;
ckb::GroupConvLayout layout = ckb::GroupConvLayout::CHANNELS_LAST;
ckb::DataType data_type = ckb::DataType::FP16;
};
static_assert(ckb::ConvSignatureDescriptor<ConvSignature>);