Merge commit 'b387249fd905b595f2d38ac2a18d8c2aa9b88c00' into develop

This commit is contained in:
assistant-librarian[bot]
2025-10-31 00:35:07 +00:00
parent ab856a3e02
commit c41df57bad
14 changed files with 318 additions and 43 deletions

View File

@@ -158,6 +158,28 @@ struct ConvTensorLayouts<GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK, 3, ConvDirecti
using ELayout = ck::tensor_layout::convolution::GNDHWK;
};
template <GroupConvLayout Layout, size_t SPATIAL_DIM, ConvDirection DIR>
consteval auto GetTensorLayout()
{
if constexpr(SPATIAL_DIM == 1)
{
return factory_internal::ConvTensorLayouts<Layout._1d, 1, DIR>{};
}
else if constexpr(SPATIAL_DIM == 2)
{
return factory_internal::ConvTensorLayouts<Layout._2d, 2, DIR>{};
}
else if constexpr(SPATIAL_DIM == 3)
{
return factory_internal::ConvTensorLayouts<Layout._3d, 3, DIR>{};
}
else
{
static_assert(false, "Unsupported spatial dimension for convolution layout.");
}
}
// Type mappings from builder convolution data type to CK tensor types.
template <DataType T>
struct ConvTensorTypes
@@ -432,16 +454,19 @@ template <ConvSignatureDescriptor auto SIGNATURE,
auto VERSION>
struct ConvFactory;
// Factory specialization for an instance of a grouped forward convolution kernel.
// Factory specialization for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 instance
// of a grouped forward convolution kernel.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsForward<SIGNATURE>
requires ConvDirectionIsForward<SIGNATURE> &&
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<SIGNATURE>
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts =
factory_internal::ConvTensorLayouts<SIGNATURE.layout, SPATIAL_DIM, ConvDirection::FORWARD>;
using Layouts = decltype(factory_internal::GetTensorLayout<SIGNATURE.layout,
SPATIAL_DIM,
ConvDirection::FORWARD>());
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
using AlgorithmType = decltype(ALGORITHM);

View File

@@ -21,6 +21,7 @@
#include <type_traits>
#include "ck_tile/builder/types.hpp"
#include "ck_tile/builder/conv_signature_predicates.hpp"
namespace ck_tile::builder {
@@ -40,16 +41,21 @@ template <DataType T>
concept ConvDataType = (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16) ||
(T == DataType::FP8) || (T == DataType::I8) || (T == DataType::U8);
template <typename T>
concept ConvDeviceOp = std::same_as<std::remove_cvref_t<T>, GroupConvDeviceOp>;
template <typename T>
concept ConvLayout = std::same_as<std::remove_cvref_t<T>, GroupConvLayout>;
// Concept for a type that defines a convolution's operational signature.
template <typename T>
concept ConvSignatureDescriptor = requires(T t) {
{ t.spatial_dim } -> std::convertible_to<unsigned int>;
{ t.direction } -> std::convertible_to<ConvDirection>;
requires std::convertible_to<decltype(t.layout), GroupConvLayout1D> ||
std::convertible_to<decltype(t.layout), GroupConvLayout2D> ||
std::convertible_to<decltype(t.layout), GroupConvLayout3D>;
{ t.layout } -> ConvLayout;
{ t.data_type } -> std::convertible_to<DataType>;
{ t.elementwise_operation } -> std::convertible_to<ElementwiseOperation>;
{ t.device_operation } -> ConvDeviceOp;
};
// Concept to validate a convolution signature's values.
@@ -57,18 +63,7 @@ template <auto Sig>
concept ValidConvSignature = requires {
requires ConvSpatialDim<Sig.spatial_dim>;
requires ConvDataType<Sig.data_type>;
requires IsValidConvDeviceOp<Sig>;
};
// Predicate for forward convolution.
template <auto Sig>
concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD);
// Predicate for backward data convolution.
template <auto Sig>
concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA);
// Predicate for backward weight convolution.
template <auto Sig>
concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT);
} // namespace ck_tile::builder

View File

@@ -0,0 +1,174 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <concepts>
#include <type_traits>
#include "ck_tile/builder/types.hpp"
namespace ck_tile::builder {
/**********************************************
* Conv Direction Predicates
**********************************************/
// Predicate for forward convolution.
template <auto Sig>
concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD);
// Predicate for backward data convolution.
template <auto Sig>
concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA);
// Predicate for backward weight convolution.
template <auto Sig>
concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT);
/**********************************************
* Conv Fwd Device Op Predicates
**********************************************/
// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
(Sig.device_operation._fwd ==
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3);
// Predicate for DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
(Sig.device_operation._fwd ==
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK);
// Predicate for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle =
(Sig.device_operation._fwd ==
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle);
// Predicate for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
(Sig.device_operation._fwd ==
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle);
// Predicate for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
(Sig.device_operation._fwd ==
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor);
// Generic predicate to check if signature uses any forward convolution device operation.
template <auto Sig>
concept ConvDeviceOpIsForward =
ConvDeviceOpIs_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<Sig>;
/**********************************************
* Conv Bwd Weight Device Op Predicates
**********************************************/
// Predicate for DeviceGroupedConvBwdWeight operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight =
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight);
// Predicate for DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle =
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle);
// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle =
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffle);
// Predicate for DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle =
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle);
// Predicate for DeviceGroupedConvBwdWeight_Wmma_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle =
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Wmma_CShuffle);
// Predicate for DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3 =
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3);
// Predicate for DeviceGroupedConvBwdWeightMultipleD operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD =
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeightMultipleD);
// Predicate for DeviceGroupedConvBwdWeight_Dl operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl =
(Sig.device_operation._bwd_weight ==
BwdWeightGroupConvDeviceOperation::DeviceGroupedConvBwdWeight_Dl);
// Generic predicate to check if signature uses any backward weight convolution device operation.
template <auto Sig>
concept ConvDeviceOpIsBackwardWeight =
ConvDeviceOpIs_DeviceGroupedConvBwdWeight<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffle<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Wmma_CShuffle<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Xdl_CShuffleV3<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvBwdWeightMultipleD<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvBwdWeight_Dl<Sig>;
/**********************************************
* Conv Bwd Data Device Op Predicates
**********************************************/
// Predicate for DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1 =
(Sig.device_operation._bwd_data ==
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1);
// Predicate for DeviceGroupedConvBwdDataMultipleD operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD =
(Sig.device_operation._bwd_data ==
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD);
// Predicate for DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle operation.
template <auto Sig>
concept ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle =
(Sig.device_operation._bwd_data ==
BwdDataGroupConvDeviceOperation::DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle);
// Generic predicate to check if signature uses any backward data convolution device operation.
template <auto Sig>
concept ConvDeviceOpIsBackwardData =
ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD<Sig> ||
ConvDeviceOpIs_DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle<Sig>;
/**********************************************
* Generic Device Op Predicates
**********************************************/
// Generic predicate to check if signature uses any device operation.
template <auto Sig>
concept IsValidConvDeviceOp = ConvDeviceOpIsForward<Sig> || ConvDeviceOpIsBackwardData<Sig> ||
ConvDeviceOpIsBackwardWeight<Sig>;
} // namespace ck_tile::builder

View File

@@ -69,7 +69,8 @@ template <ck::index_t NDimSpatial,
ck::BlockGemmPipelineScheduler BlkGemmPipeSched,
ck::BlockGemmPipelineVersion BlkGemmPipelineVer,
typename AComputeDataType,
typename BComputeDataType>
typename BComputeDataType,
bool DirectLoad>
struct DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
} // namespace ck::tensor_operation::device
@@ -124,7 +125,8 @@ template <ck::index_t NDimSpatial,
ck::BlockGemmPipelineScheduler BlkGemmPipeSched,
ck::BlockGemmPipelineVersion BlkGemmPipelineVer,
typename AComputeDataType_,
typename BComputeDataType_>
typename BComputeDataType_,
bool DirectLoad>
struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
NDimSpatial,
ALayout_,
@@ -173,7 +175,8 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
BlkGemmPipeSched,
BlkGemmPipelineVer,
AComputeDataType_,
BComputeDataType_>>
BComputeDataType_,
DirectLoad>>
{
// Spatial dimension
static constexpr int kSpatialDim = NDimSpatial;
@@ -336,6 +339,7 @@ struct InstanceTraits<ck::tensor_operation::device::DeviceGroupedConvFwdMultiple
oss << "," << detail::pipeline_version_name(kPipelineVersion); // 46. BlkGemmPipelineVer
oss << "," << detail::type_name<AComputeDataType>(); // 47. AComputeDataType
oss << "," << detail::type_name<BComputeDataType>(); // 48. BComputeDataType
oss << "," << (DirectLoad ? "true" : "false"); // 49. DirectLoad
oss << ">";
return oss.str();

View File

@@ -48,6 +48,20 @@ enum class GroupConvLayout3D
NGCDHW_GKCZYX_NGKDHW,
};
struct GroupConvLayout
{
union
{
GroupConvLayout1D _1d;
GroupConvLayout2D _2d;
GroupConvLayout3D _3d;
};
constexpr GroupConvLayout(GroupConvLayout1D layout) : _1d(layout) {}
constexpr GroupConvLayout(GroupConvLayout2D layout) : _2d(layout) {}
constexpr GroupConvLayout(GroupConvLayout3D layout) : _3d(layout) {}
};
// Direction of the convolution operation.
enum class ConvDirection
{
@@ -56,6 +70,52 @@ enum class ConvDirection
BACKWARD_WEIGHT
};
// Forward convolution device operations.
enum class FwdGroupConvDeviceOperation
{
DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK,
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle,
DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3,
DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
};
// Backward data convolution device operations.
enum class BwdDataGroupConvDeviceOperation
{
DeviceGroupedConvBwdDataMultipleD,
DeviceGroupedConvBwdDataMultipleD_Wmma_CShuffle,
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1
};
// Backward weight convolution device operations.
enum class BwdWeightGroupConvDeviceOperation
{
DeviceGroupedConvBwdWeight,
DeviceGroupedConvBwdWeight_Dl,
DeviceGroupedConvBwdWeight_Xdl_CShuffle,
DeviceGroupedConvBwdWeight_Xdl_CShuffleV3,
DeviceGroupedConvBwdWeight_Wmma_CShuffle,
DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle,
DeviceGroupedConvBwdWeightMultipleD,
DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle,
};
// Structural type for device operation
struct GroupConvDeviceOp
{
union
{
FwdGroupConvDeviceOperation _fwd;
BwdDataGroupConvDeviceOperation _bwd_data;
BwdWeightGroupConvDeviceOperation _bwd_weight;
};
constexpr GroupConvDeviceOp(FwdGroupConvDeviceOperation op) : _fwd(op) {}
constexpr GroupConvDeviceOp(BwdDataGroupConvDeviceOperation op) : _bwd_data(op) {}
constexpr GroupConvDeviceOp(BwdWeightGroupConvDeviceOperation op) : _bwd_weight(op) {}
};
// Fused element-wise operations.
enum class ElementwiseOperation
{

View File

@@ -9,12 +9,14 @@ namespace ck_tile::builder::testing {
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_1D_BF16_ChannelsFirst_scale)
{
constexpr ConvSignature<GroupConvLayout1D> FwdConvSignature{
constexpr ConvSignature FwdConvSignature{
.spatial_dim = 1,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout1D::NGCW_GKXC_NGKW,
.data_type = DataType::BF16,
.elementwise_operation = ElementwiseOperation::SCALE};
.elementwise_operation = ElementwiseOperation::SCALE,
.device_operation =
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};

View File

@@ -8,12 +8,14 @@ namespace ck_tile::builder::testing {
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_ChannelsLast)
{
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
constexpr ConvSignature FwdConvSignature{
.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
.data_type = DataType::BF16,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
.device_operation =
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
@@ -28,12 +30,14 @@ TEST(FwdConvInstances,
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_NHWGC_Filter3x3)
{
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
constexpr ConvSignature FwdConvSignature{
.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
.data_type = DataType::BF16,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
.device_operation =
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};

View File

@@ -7,12 +7,14 @@ namespace ck_tile::builder::testing {
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP16_GNHWC)
{
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
constexpr ConvSignature FwdConvSignature{
.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
.data_type = DataType::FP16,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
.device_operation =
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};

View File

@@ -7,12 +7,14 @@ namespace ck_tile::builder::testing {
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP32_NGCHW_GKCYX)
{
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
constexpr ConvSignature FwdConvSignature{
.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::NGCHW_GKCYX_NGKHW,
.data_type = DataType::FP32,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
.device_operation =
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 32}};

View File

@@ -8,12 +8,14 @@ namespace ck_tile::builder::testing {
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_BF16_GNDHWC)
{
constexpr ConvSignature<GroupConvLayout3D> FwdConvSignature{
constexpr ConvSignature FwdConvSignature{
.spatial_dim = 3,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK,
.data_type = DataType::BF16,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
.device_operation =
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};

View File

@@ -8,12 +8,14 @@ namespace ck_tile::builder::testing {
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP16_NDHWGC_ChannelsLast)
{
constexpr ConvSignature<GroupConvLayout3D> FwdConvSignature{
constexpr ConvSignature FwdConvSignature{
.spatial_dim = 3,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK,
.data_type = DataType::FP16,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
.device_operation =
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 32}};

View File

@@ -8,12 +8,14 @@ namespace ck_tile::builder::testing {
TEST(FwdConvInstances,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst)
{
constexpr ConvSignature<GroupConvLayout3D> FwdConvSignature{
constexpr ConvSignature FwdConvSignature{
.spatial_dim = 3,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW,
.data_type = DataType::FP32,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
.elementwise_operation = ElementwiseOperation::PASS_THROUGH,
.device_operation =
FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3};
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 32}};

View File

@@ -3,11 +3,13 @@
#pragma once
#include <variant>
#include "ck_tile/builder/conv_signature_concepts.hpp"
namespace ck_tile::builder::test {
template <typename GroupConvLayout>
using namespace ck_tile::builder;
struct ConvSignature
{
int spatial_dim;
@@ -15,9 +17,8 @@ struct ConvSignature
GroupConvLayout layout;
DataType data_type;
ElementwiseOperation elementwise_operation;
GroupConvDeviceOp device_operation;
};
static_assert(ConvSignatureDescriptor<ConvSignature<GroupConvLayout1D>>);
static_assert(ConvSignatureDescriptor<ConvSignature<GroupConvLayout2D>>);
static_assert(ConvSignatureDescriptor<ConvSignature<GroupConvLayout3D>>);
static_assert(ConvSignatureDescriptor<ConvSignature>);
} // namespace ck_tile::builder::test

View File

@@ -11,7 +11,7 @@ using namespace ck_tile::builder;
using namespace test;
// Common test implementation
template <auto FwdConvSignature,
template <ConvSignature FwdConvSignature,
ThreadBlock FwdThreadBlock,
BlockGemmPipelineVersion FwdPipelineVersion,
ConvFwdSpecialization FwdConvSpecialization>