mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_Builder] removed direction and elementwise_operation from required parameters … (#3192)
Removed direction and elementwise operation from default values required for convolution signature concept. Added constexpr helpers to set default values. Add compile-time tests.
This commit is contained in:
@@ -563,7 +563,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
SPATIAL_DIM,
|
||||
ConvDirection::FORWARD>());
|
||||
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using Ops = factory_internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static_assert(ALGORITHM.block_transfer.lds_transfer_a.is_direct_load ==
|
||||
|
||||
@@ -43,14 +43,41 @@ concept ConvDataType = (T == DataType::FP32) || (T == DataType::FP16) || (T == D
|
||||
template <typename T>
|
||||
concept ConvLayout = std::same_as<std::remove_cvref_t<T>, GroupConvLayout>;
|
||||
|
||||
template <typename T>
|
||||
concept HasElementwiseOp = requires(T t) {
|
||||
{ t.elementwise_operation };
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept HasConvolutionDirection = requires(T t) {
|
||||
{ t.direction };
|
||||
};
|
||||
|
||||
// Note: it is not required to provide an ElementwiseOp, but if one is provided, check if well
|
||||
// defined
|
||||
template <typename T>
|
||||
concept ElementwiseOpWellDefinedIfProvided = requires(T t) {
|
||||
requires !HasElementwiseOp<T> || requires {
|
||||
{ t.elementwise_operation } -> std::convertible_to<ElementwiseOperation>;
|
||||
};
|
||||
};
|
||||
|
||||
// Note: it is not required to provide a convolution, but if one is provided, check if well defined
|
||||
template <typename T>
|
||||
concept ConvolutionDirectionWellDefinedIfProvided = requires(T t) {
|
||||
requires !HasConvolutionDirection<T> || requires {
|
||||
{ t.direction } -> std::convertible_to<ConvDirection>;
|
||||
};
|
||||
};
|
||||
|
||||
// Concept for a type that defines a convolution's operational signature.
|
||||
template <typename T>
|
||||
concept ConvSignatureDescriptor = requires(T t) {
|
||||
{ t.spatial_dim } -> std::convertible_to<unsigned int>;
|
||||
{ t.direction } -> std::convertible_to<ConvDirection>;
|
||||
{ t.layout } -> ConvLayout;
|
||||
{ t.data_type } -> std::convertible_to<DataType>;
|
||||
{ t.elementwise_operation } -> std::convertible_to<ElementwiseOperation>;
|
||||
requires ElementwiseOpWellDefinedIfProvided<T>;
|
||||
requires ConvolutionDirectionWellDefinedIfProvided<T>;
|
||||
};
|
||||
|
||||
// Concept to validate a convolution signature's values.
|
||||
|
||||
@@ -19,6 +19,17 @@ namespace ckt = ck_tile::test;
|
||||
// Defines the signature of the convolution operation to be tested.
|
||||
// This includes dimensionality, direction, data layout, and data type.
|
||||
struct ConvSignature
|
||||
{
|
||||
int spatial_dim = 2;
|
||||
ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
|
||||
ckb::DataType data_type = ckb::DataType::FP16;
|
||||
ckb::GroupConvDeviceOp device_operation =
|
||||
ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
|
||||
};
|
||||
static_assert(ckb::ConvSignatureDescriptor<ConvSignature>);
|
||||
|
||||
// Compile time tests for concepts
|
||||
struct ConvSignatureWithOptionalParams
|
||||
{
|
||||
int spatial_dim = 2;
|
||||
ckb::ConvDirection direction = ckb::ConvDirection::FORWARD;
|
||||
@@ -26,7 +37,19 @@ struct ConvSignature
|
||||
ckb::DataType data_type = ckb::DataType::FP16;
|
||||
ckb::ElementwiseOperation elementwise_operation = ckb::ElementwiseOperation::PASS_THROUGH;
|
||||
};
|
||||
static_assert(ckb::ConvSignatureDescriptor<ConvSignature>);
|
||||
static_assert(ckb::ConvSignatureDescriptor<ConvSignatureWithOptionalParams>);
|
||||
|
||||
struct ConvSignatureWithInvalidOptionalParams
|
||||
{
|
||||
int spatial_dim = 2;
|
||||
ckb::ConvDirection direction = ckb::ConvDirection::FORWARD;
|
||||
ckb::GroupConvLayout layout = ckb::GroupConvLayout2D::GNHWC_GKYXC_GNHWK;
|
||||
ckb::DataType data_type = ckb::DataType::FP16;
|
||||
int elementwise_operation = 7; // this should fail
|
||||
ckb::GroupConvDeviceOp device_operation =
|
||||
ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
|
||||
};
|
||||
static_assert(!ckb::ConvSignatureDescriptor<ConvSignatureWithInvalidOptionalParams>);
|
||||
|
||||
struct DefaultAlgorithm
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user