mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
adressed review comments from PR3459 (#3526)
Co-authored-by: Kevin Abraham <kevin.abraham@streamhpc.com>
This commit is contained in:
@@ -85,21 +85,23 @@ The top-level signature contains global properties that apply to the entire conv
|
||||
template <typename T>
|
||||
concept ConvSignatureDescriptor = requires(T t) {
|
||||
{ t.spatial_dim } -> std::convertible_to<unsigned int>; // 1, 2, or 3
|
||||
{ t.data_type } -> std::convertible_to<DataType>; // Default data type
|
||||
{ t.input } -> ConvTensorDescriptor;
|
||||
{ t.weight } -> ConvTensorDescriptor;
|
||||
{ t.output } -> ConvTensorDescriptor;
|
||||
requires ConvolutionDirectionWellDefinedIfProvided<T>; // Optional direction
|
||||
requires detail::DataTypeWellDefinedIfProvided<T>; // Optional default data type
|
||||
requires detail::ElementwiseOpWellDefinedIfProvided<T>; // Optional default elementwise operation
|
||||
};
|
||||
```
|
||||
|
||||
**Properties:**
|
||||
- **`spatial_dim`**: Dimensionality of the convolution (1D, 2D, or 3D)
|
||||
- **`direction`**: Operation type (optional, defaults to FORWARD)
|
||||
- **`direction`**: Operation type (Optional, defaults to FORWARD)
|
||||
- `FORWARD`: Standard forward convolution
|
||||
- `BACKWARD_DATA`: Gradient computation w.r.t. input
|
||||
- `BACKWARD_WEIGHT`: Gradient computation w.r.t. weights
|
||||
- **`data_type`**: Default data type for all tensors (FP32, FP16, BF16, FP8, I8, U8)
|
||||
- **`data_type`**: Default data type for all tensors (FP32, FP16, BF16, FP8, I8, U8). (Optional, defaults to UNDEFINED_DATA_TYPE, may be overridden by tensors)
|
||||
- **`operation`**: Default Operation (Optional, defaults to PASS_THROUGH, may be overridden by tensors)
|
||||
- **`accumulation_data_type`**: Type used for internal accumulation
|
||||
|
||||
#### 2. Tensor Level
|
||||
@@ -116,7 +118,7 @@ concept ConvTensorDescriptor = requires(T t) {
|
||||
|
||||
A tensor descriptor encapsulates:
|
||||
- **Configuration**: Layout and data type information
|
||||
- **Operation** (optional): Fused elementwise operations on this tensor
|
||||
- **operation** Fused elementwise operations on this tensor (Optional, default provided by ConvSignatureDescriptor)
|
||||
|
||||
#### 3. Tensor Configuration
|
||||
|
||||
@@ -126,7 +128,7 @@ Describes the memory layout and data types:
|
||||
template <typename T>
|
||||
concept TensorConfigDescriptor = requires(T t) {
|
||||
{ t.layout } -> std::convertible_to<ConvLayout>;
|
||||
{ t.data_type } -> std::convertible_to<DataType>; // Optional override
|
||||
requires detail::DataTypeWellDefinedIfProvided<T>; // Override data type (Optional, default provided by ConvSignatureDescriptor)
|
||||
};
|
||||
```
|
||||
|
||||
|
||||
@@ -80,6 +80,7 @@ concept ConvOutputLayout3D =
|
||||
(L == TensorLayout::GNKDHW) || (L == TensorLayout::GNDHWK) || (L == TensorLayout::NDHWGK) ||
|
||||
(L == TensorLayout::NGKDHW) || (L == TensorLayout::G_NDHW_K_strided);
|
||||
|
||||
namespace detail {
|
||||
template <typename T>
|
||||
concept HasDataType = requires(T t) {
|
||||
{ t.data_type };
|
||||
@@ -94,10 +95,11 @@ concept DataTypeWellDefinedIfProvided = requires(T t) {
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
template <typename T>
|
||||
concept TensorConfigDescriptor = requires(T t) {
|
||||
{ t.layout } -> std::convertible_to<TensorLayout>;
|
||||
requires DataTypeWellDefinedIfProvided<T>;
|
||||
requires detail::DataTypeWellDefinedIfProvided<T>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@@ -116,7 +118,6 @@ template <typename T, std::size_t N>
|
||||
struct IsArrayOfTensorConfigDescriptors<std::array<T, N>> : std::true_type
|
||||
{
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <typename T>
|
||||
concept ConvertibleToArrayOfTensorConfigs =
|
||||
@@ -128,11 +129,12 @@ concept AuxiliaryOperandConfigsWellDefinedIfProvided = requires(T t) {
|
||||
{ t.auxiliary_operand_configs } -> ConvertibleToArrayOfTensorConfigs;
|
||||
};
|
||||
};
|
||||
} // namespace detail
|
||||
|
||||
template <typename T>
|
||||
concept TensorOperatorDescriptor = requires(T t) {
|
||||
{ t.elementwise_operation } -> std::convertible_to<ElementwiseOperation>;
|
||||
requires AuxiliaryOperandConfigsWellDefinedIfProvided<T>;
|
||||
requires detail::AuxiliaryOperandConfigsWellDefinedIfProvided<T>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@@ -140,6 +142,8 @@ concept HasTensorOp = requires(T t) {
|
||||
{ t.operation };
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename T>
|
||||
concept HasConvolutionDirection = requires(T t) {
|
||||
{ t.direction };
|
||||
@@ -159,11 +163,13 @@ concept ConvolutionDirectionWellDefinedIfProvided = requires(T t) {
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
// Concept for the convolution tensor
|
||||
template <typename T>
|
||||
concept ConvTensorDescriptor = requires(T t) {
|
||||
{ t.config } -> TensorConfigDescriptor;
|
||||
requires ElementwiseOpWellDefinedIfProvided<T>;
|
||||
requires detail::ElementwiseOpWellDefinedIfProvided<T>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
@@ -179,8 +185,9 @@ concept ConvSignatureDescriptor = requires(T t) {
|
||||
{ t.input } -> ConvTensorDescriptor;
|
||||
{ t.weight } -> ConvTensorDescriptor;
|
||||
{ t.output } -> ConvTensorDescriptor;
|
||||
requires ConvolutionDirectionWellDefinedIfProvided<T>;
|
||||
requires DataTypeWellDefinedIfProvided<T>;
|
||||
requires detail::ConvolutionDirectionWellDefinedIfProvided<T>;
|
||||
requires detail::DataTypeWellDefinedIfProvided<T>;
|
||||
requires detail::ElementwiseOpWellDefinedIfProvided<T>;
|
||||
};
|
||||
|
||||
// Concept to validate a convolution signature's values.
|
||||
|
||||
Reference in New Issue
Block a user