adressed review comments from PR3459 (#3526)

Co-authored-by: Kevin Abraham <kevin.abraham@streamhpc.com>
This commit is contained in:
kabrahamAMD
2026-01-12 09:47:00 +01:00
committed by GitHub
parent b352a68606
commit 20f66c1e6b
2 changed files with 20 additions and 11 deletions

View File

@@ -85,21 +85,23 @@ The top-level signature contains global properties that apply to the entire conv
template <typename T>
concept ConvSignatureDescriptor = requires(T t) {
{ t.spatial_dim } -> std::convertible_to<unsigned int>; // 1, 2, or 3
{ t.data_type } -> std::convertible_to<DataType>; // Default data type
{ t.input } -> ConvTensorDescriptor;
{ t.weight } -> ConvTensorDescriptor;
{ t.output } -> ConvTensorDescriptor;
requires ConvolutionDirectionWellDefinedIfProvided<T>; // Optional direction
requires detail::DataTypeWellDefinedIfProvided<T>; // Optional default data type
requires detail::ElementwiseOpWellDefinedIfProvided<T>; // Optional default elementwise operation
};
```
**Properties:**
- **`spatial_dim`**: Dimensionality of the convolution (1D, 2D, or 3D)
- **`direction`**: Operation type (optional, defaults to FORWARD)
- **`direction`**: Operation type (Optional, defaults to FORWARD)
- `FORWARD`: Standard forward convolution
- `BACKWARD_DATA`: Gradient computation w.r.t. input
- `BACKWARD_WEIGHT`: Gradient computation w.r.t. weights
- **`data_type`**: Default data type for all tensors (FP32, FP16, BF16, FP8, I8, U8)
- **`data_type`**: Default data type for all tensors (FP32, FP16, BF16, FP8, I8, U8). (Optional, defaults to UNDEFINED_DATA_TYPE, may be overridden by tensors)
- **`operation`**: Default Operation (Optional, defaults to PASS_THROUGH, may be overridden by tensors)
- **`accumulation_data_type`**: Type used for internal accumulation
#### 2. Tensor Level
@@ -116,7 +118,7 @@ concept ConvTensorDescriptor = requires(T t) {
A tensor descriptor encapsulates:
- **Configuration**: Layout and data type information
- **Operation** (optional): Fused elementwise operations on this tensor
- **operation** Fused elementwise operations on this tensor (Optional, default provided by ConvSignatureDescriptor)
#### 3. Tensor Configuration
@@ -126,7 +128,7 @@ Describes the memory layout and data types:
template <typename T>
concept TensorConfigDescriptor = requires(T t) {
{ t.layout } -> std::convertible_to<ConvLayout>;
{ t.data_type } -> std::convertible_to<DataType>; // Optional override
requires detail::DataTypeWellDefinedIfProvided<T>; // Override data type (Optional, default provided by ConvSignatureDescriptor)
};
```

View File

@@ -80,6 +80,7 @@ concept ConvOutputLayout3D =
(L == TensorLayout::GNKDHW) || (L == TensorLayout::GNDHWK) || (L == TensorLayout::NDHWGK) ||
(L == TensorLayout::NGKDHW) || (L == TensorLayout::G_NDHW_K_strided);
namespace detail {
template <typename T>
concept HasDataType = requires(T t) {
{ t.data_type };
@@ -94,10 +95,11 @@ concept DataTypeWellDefinedIfProvided = requires(T t) {
};
};
} // namespace detail
template <typename T>
concept TensorConfigDescriptor = requires(T t) {
{ t.layout } -> std::convertible_to<TensorLayout>;
requires DataTypeWellDefinedIfProvided<T>;
requires detail::DataTypeWellDefinedIfProvided<T>;
};
template <typename T>
@@ -116,7 +118,6 @@ template <typename T, std::size_t N>
struct IsArrayOfTensorConfigDescriptors<std::array<T, N>> : std::true_type
{
};
} // namespace detail
template <typename T>
concept ConvertibleToArrayOfTensorConfigs =
@@ -128,11 +129,12 @@ concept AuxiliaryOperandConfigsWellDefinedIfProvided = requires(T t) {
{ t.auxiliary_operand_configs } -> ConvertibleToArrayOfTensorConfigs;
};
};
} // namespace detail
template <typename T>
concept TensorOperatorDescriptor = requires(T t) {
{ t.elementwise_operation } -> std::convertible_to<ElementwiseOperation>;
requires AuxiliaryOperandConfigsWellDefinedIfProvided<T>;
requires detail::AuxiliaryOperandConfigsWellDefinedIfProvided<T>;
};
template <typename T>
@@ -140,6 +142,8 @@ concept HasTensorOp = requires(T t) {
{ t.operation };
};
namespace detail {
template <typename T>
concept HasConvolutionDirection = requires(T t) {
{ t.direction };
@@ -159,11 +163,13 @@ concept ConvolutionDirectionWellDefinedIfProvided = requires(T t) {
};
};
} // namespace detail
// Concept for the convolution tensor
template <typename T>
concept ConvTensorDescriptor = requires(T t) {
{ t.config } -> TensorConfigDescriptor;
requires ElementwiseOpWellDefinedIfProvided<T>;
requires detail::ElementwiseOpWellDefinedIfProvided<T>;
};
template <typename T>
@@ -179,8 +185,9 @@ concept ConvSignatureDescriptor = requires(T t) {
{ t.input } -> ConvTensorDescriptor;
{ t.weight } -> ConvTensorDescriptor;
{ t.output } -> ConvTensorDescriptor;
requires ConvolutionDirectionWellDefinedIfProvided<T>;
requires DataTypeWellDefinedIfProvided<T>;
requires detail::ConvolutionDirectionWellDefinedIfProvided<T>;
requires detail::DataTypeWellDefinedIfProvided<T>;
requires detail::ElementwiseOpWellDefinedIfProvided<T>;
};
// Concept to validate a convolution signature's values.