mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 12:17:00 +00:00
Add device operation to conv signature. Use unions to hold conv layouts and device operations.
This commit is contained in:
@@ -158,6 +158,28 @@ struct ConvTensorLayouts<GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK, 3, ConvDirecti
|
||||
using ELayout = ck::tensor_layout::convolution::GNDHWK;
|
||||
};
|
||||
|
||||
template <GroupConvLayout Layout, size_t SPATIAL_DIM, ConvDirection DIR>
|
||||
consteval auto GetTensorLayout()
|
||||
{
|
||||
|
||||
if constexpr(SPATIAL_DIM == 1)
|
||||
{
|
||||
return factory_internal::ConvTensorLayouts<Layout._1d, 1, DIR>{};
|
||||
}
|
||||
else if constexpr(SPATIAL_DIM == 2)
|
||||
{
|
||||
return factory_internal::ConvTensorLayouts<Layout._2d, 2, DIR>{};
|
||||
}
|
||||
else if constexpr(SPATIAL_DIM == 3)
|
||||
{
|
||||
return factory_internal::ConvTensorLayouts<Layout._3d, 3, DIR>{};
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unsupported spatial dimension for convolution layout.");
|
||||
}
|
||||
}
|
||||
|
||||
// Type mappings from builder convolution data type to CK tensor types.
|
||||
template <DataType T>
|
||||
struct ConvTensorTypes
|
||||
@@ -440,8 +462,8 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts =
|
||||
factory_internal::ConvTensorLayouts<SIGNATURE.layout, SPATIAL_DIM, ConvDirection::FORWARD>;
|
||||
/*static constexpr auto*/
|
||||
using Layouts = decltype(factory_internal::GetTensorLayout<SIGNATURE.layout, SPATIAL_DIM, ConvDirection::FORWARD>());
|
||||
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
// signature at compile time.
|
||||
#pragma once
|
||||
|
||||
#include <variant>
|
||||
#include <concepts>
|
||||
#include <type_traits>
|
||||
|
||||
@@ -40,16 +41,21 @@ template <DataType T>
|
||||
concept ConvDataType = (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16) ||
|
||||
(T == DataType::FP8) || (T == DataType::I8) || (T == DataType::U8);
|
||||
|
||||
template <typename T>
|
||||
concept ConvDeviceOp = std::same_as<std::remove_cvref_t<T>, GroupConvDeviceOp>;
|
||||
|
||||
template <typename T>
|
||||
concept ConvLayout = std::same_as<std::remove_cvref_t<T>, GroupConvLayout>;
|
||||
|
||||
// Concept for a type that defines a convolution's operational signature.
|
||||
template <typename T>
|
||||
concept ConvSignatureDescriptor = requires(T t) {
|
||||
{ t.spatial_dim } -> std::convertible_to<unsigned int>;
|
||||
{ t.direction } -> std::convertible_to<ConvDirection>;
|
||||
requires std::convertible_to<decltype(t.layout), GroupConvLayout1D> ||
|
||||
std::convertible_to<decltype(t.layout), GroupConvLayout2D> ||
|
||||
std::convertible_to<decltype(t.layout), GroupConvLayout3D>;
|
||||
{ t.layout } -> ConvLayout;
|
||||
{ t.data_type } -> std::convertible_to<DataType>;
|
||||
{ t.elementwise_operation } -> std::convertible_to<ElementwiseOperation>;
|
||||
{ t.device_operation } -> ConvDeviceOp;
|
||||
};
|
||||
|
||||
// Concept to validate a convolution signature's values.
|
||||
|
||||
@@ -48,6 +48,18 @@ enum class GroupConvLayout3D
|
||||
NGCDHW_GKCZYX_NGKDHW,
|
||||
};
|
||||
|
||||
struct GroupConvLayout {
|
||||
union {
|
||||
GroupConvLayout1D _1d;
|
||||
GroupConvLayout2D _2d;
|
||||
GroupConvLayout3D _3d;
|
||||
};
|
||||
|
||||
constexpr GroupConvLayout(GroupConvLayout1D layout) : _1d(layout) {}
|
||||
constexpr GroupConvLayout(GroupConvLayout2D layout) : _2d(layout) {}
|
||||
constexpr GroupConvLayout(GroupConvLayout3D layout) : _3d(layout) {}
|
||||
};
|
||||
|
||||
// Direction of the convolution operation.
|
||||
enum class ConvDirection
|
||||
{
|
||||
@@ -56,6 +68,50 @@ enum class ConvDirection
|
||||
BACKWARD_WEIGHT
|
||||
};
|
||||
|
||||
// Forward convolution device operations.
|
||||
enum class FwdGroupConvDeviceOperation
|
||||
{
|
||||
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle,
|
||||
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3,
|
||||
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
|
||||
};
|
||||
|
||||
// Backward data convolution device operations.
|
||||
enum class BwdDataGroupConvDeviceOperation
|
||||
{
|
||||
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1,
|
||||
DeviceGroupedConvBwdDataMultipleD,
|
||||
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle
|
||||
};
|
||||
|
||||
// Backward weight convolution device operations.
|
||||
enum class BwdWeightGroupConvDeviceOperation
|
||||
{
|
||||
DeviceGroupedConvBwdWeight,
|
||||
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffle,
|
||||
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle,
|
||||
DeviceGroupedConvBwdWeight_Wmma_CShuffle,
|
||||
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3,
|
||||
DeviceGroupedConvBwdWeightMultipleD,
|
||||
DeviceGroupedConvBwdWeight_Dl
|
||||
};
|
||||
|
||||
// Structural type for device operation
|
||||
struct GroupConvDeviceOp {
|
||||
union {
|
||||
FwdGroupConvDeviceOperation _fwd;
|
||||
BwdDataGroupConvDeviceOperation _bwd_data;
|
||||
BwdWeightGroupConvDeviceOperation _bwd_weight;
|
||||
};
|
||||
|
||||
constexpr GroupConvDeviceOp(FwdGroupConvDeviceOperation op) : _fwd(op) {}
|
||||
constexpr GroupConvDeviceOp(BwdDataGroupConvDeviceOperation op) : _bwd_data(op) {}
|
||||
constexpr GroupConvDeviceOp(BwdWeightGroupConvDeviceOperation op) : _bwd_weight(op) {}
|
||||
};
|
||||
|
||||
// Fused element-wise operations.
|
||||
enum class ElementwiseOperation
|
||||
{
|
||||
|
||||
@@ -9,12 +9,14 @@ namespace ck_tile::builder::testing {
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_1D_BF16_ChannelsFirst_scale)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout1D> FwdConvSignature{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 1,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout1D::NGCW_GKXC_NGKW,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation = ElementwiseOperation::SCALE};
|
||||
.elementwise_operation = ElementwiseOperation::SCALE,
|
||||
.device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
@@ -8,12 +8,14 @@ namespace ck_tile::builder::testing {
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_ChannelsLast)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
@@ -28,12 +30,14 @@ TEST(FwdConvInstances,
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_NHWGC_Filter3x3)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
@@ -7,12 +7,14 @@ namespace ck_tile::builder::testing {
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP16_GNHWC)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
@@ -7,12 +7,14 @@ namespace ck_tile::builder::testing {
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP32_NGCHW_GKCYX)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NGCHW_GKCYX_NGKHW,
|
||||
.data_type = DataType::FP32,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
@@ -8,12 +8,14 @@ namespace ck_tile::builder::testing {
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_BF16_GNDHWC)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout3D> FwdConvSignature{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 3,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
@@ -8,12 +8,14 @@ namespace ck_tile::builder::testing {
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP16_NDHWGC_ChannelsLast)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout3D> FwdConvSignature{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 3,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
@@ -8,12 +8,14 @@ namespace ck_tile::builder::testing {
|
||||
TEST(FwdConvInstances,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout3D> FwdConvSignature{
|
||||
constexpr ConvSignature FwdConvSignature{
|
||||
.spatial_dim = 3,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW,
|
||||
.data_type = DataType::FP32,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
|
||||
.device_operation = FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
@@ -3,11 +3,13 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <variant>
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
template <typename GroupConvLayout>
|
||||
using namespace ck_tile::builder;
|
||||
|
||||
struct ConvSignature
|
||||
{
|
||||
int spatial_dim;
|
||||
@@ -15,9 +17,8 @@ struct ConvSignature
|
||||
GroupConvLayout layout;
|
||||
DataType data_type;
|
||||
ElementwiseOperation elementwise_operation;
|
||||
GroupConvDeviceOp device_operation;
|
||||
};
|
||||
static_assert(ConvSignatureDescriptor<ConvSignature<GroupConvLayout1D>>);
|
||||
static_assert(ConvSignatureDescriptor<ConvSignature<GroupConvLayout2D>>);
|
||||
static_assert(ConvSignatureDescriptor<ConvSignature<GroupConvLayout3D>>);
|
||||
static_assert(ConvSignatureDescriptor<ConvSignature>);
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
@@ -11,7 +11,7 @@ using namespace ck_tile::builder;
|
||||
using namespace test;
|
||||
|
||||
// Common test implementation
|
||||
template <auto FwdConvSignature,
|
||||
template <ConvSignature FwdConvSignature,
|
||||
ThreadBlock FwdThreadBlock,
|
||||
BlockGemmPipelineVersion FwdPipelineVersion,
|
||||
ConvFwdSpecialization FwdConvSpecialization>
|
||||
|
||||
Reference in New Issue
Block a user