diff --git a/experimental/builder/include/ck_tile/builder/README.md b/experimental/builder/include/ck_tile/builder/README.md index 8075e33220..af8c4ec01b 100644 --- a/experimental/builder/include/ck_tile/builder/README.md +++ b/experimental/builder/include/ck_tile/builder/README.md @@ -85,21 +85,23 @@ The top-level signature contains global properties that apply to the entire conv template concept ConvSignatureDescriptor = requires(T t) { { t.spatial_dim } -> std::convertible_to; // 1, 2, or 3 - { t.data_type } -> std::convertible_to; // Default data type { t.input } -> ConvTensorDescriptor; { t.weight } -> ConvTensorDescriptor; { t.output } -> ConvTensorDescriptor; requires ConvolutionDirectionWellDefinedIfProvided; // Optional direction + requires detail::DataTypeWellDefinedIfProvided; // Optional default data type + requires detail::ElementwiseOpWellDefinedIfProvided; // Optional default elementwise operation }; ``` **Properties:** - **`spatial_dim`**: Dimensionality of the convolution (1D, 2D, or 3D) -- **`direction`**: Operation type (optional, defaults to FORWARD) +- **`direction`**: Operation type (Optional, defaults to FORWARD) - `FORWARD`: Standard forward convolution - `BACKWARD_DATA`: Gradient computation w.r.t. input - `BACKWARD_WEIGHT`: Gradient computation w.r.t. weights -- **`data_type`**: Default data type for all tensors (FP32, FP16, BF16, FP8, I8, U8) +- **`data_type`**: Default data type for all tensors (FP32, FP16, BF16, FP8, I8, U8). (Optional, defaults to UNDEFINED_DATA_TYPE, may be overridden by tensors) +- **`operation`**: Default Operation (Optional, defaults to PASS_THROUGH, may be overridden by tensors) - **`accumulation_data_type`**: Type used for internal accumulation #### 2. Tensor Level @@ -116,7 +118,7 @@ concept ConvTensorDescriptor = requires(T t) { A tensor descriptor encapsulates: - **Configuration**: Layout and data type information -- **Operation** (optional): Fused elementwise operations on this tensor +- **operation** Fused elementwise operations on this tensor (Optional, default provided by ConvSignatureDescriptor) #### 3. Tensor Configuration @@ -126,7 +128,7 @@ Describes the memory layout and data types: template concept TensorConfigDescriptor = requires(T t) { { t.layout } -> std::convertible_to; - { t.data_type } -> std::convertible_to; // Optional override + requires detail::DataTypeWellDefinedIfProvided; // Override data type (Optional, default provided by ConvSignatureDescriptor) }; ``` diff --git a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp index 39e081ec8d..f085283381 100644 --- a/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_signature_concepts.hpp @@ -80,6 +80,7 @@ concept ConvOutputLayout3D = (L == TensorLayout::GNKDHW) || (L == TensorLayout::GNDHWK) || (L == TensorLayout::NDHWGK) || (L == TensorLayout::NGKDHW) || (L == TensorLayout::G_NDHW_K_strided); +namespace detail { template concept HasDataType = requires(T t) { { t.data_type }; @@ -94,10 +95,11 @@ concept DataTypeWellDefinedIfProvided = requires(T t) { }; }; +} // namespace detail template concept TensorConfigDescriptor = requires(T t) { { t.layout } -> std::convertible_to; - requires DataTypeWellDefinedIfProvided; + requires detail::DataTypeWellDefinedIfProvided; }; template @@ -116,7 +118,6 @@ template struct IsArrayOfTensorConfigDescriptors> : std::true_type { }; -} // namespace detail template concept ConvertibleToArrayOfTensorConfigs = @@ -128,11 +129,12 @@ concept AuxiliaryOperandConfigsWellDefinedIfProvided = requires(T t) { { t.auxiliary_operand_configs } -> ConvertibleToArrayOfTensorConfigs; }; }; +} // namespace detail template concept TensorOperatorDescriptor = requires(T t) { { t.elementwise_operation } -> std::convertible_to; - requires AuxiliaryOperandConfigsWellDefinedIfProvided; + requires detail::AuxiliaryOperandConfigsWellDefinedIfProvided; }; template @@ -140,6 +142,8 @@ concept HasTensorOp = requires(T t) { { t.operation }; }; +namespace detail { + template concept HasConvolutionDirection = requires(T t) { { t.direction }; @@ -159,11 +163,13 @@ concept ConvolutionDirectionWellDefinedIfProvided = requires(T t) { }; }; +} // namespace detail + // Concept for the convolution tensor template concept ConvTensorDescriptor = requires(T t) { { t.config } -> TensorConfigDescriptor; - requires ElementwiseOpWellDefinedIfProvided; + requires detail::ElementwiseOpWellDefinedIfProvided; }; template @@ -179,8 +185,9 @@ concept ConvSignatureDescriptor = requires(T t) { { t.input } -> ConvTensorDescriptor; { t.weight } -> ConvTensorDescriptor; { t.output } -> ConvTensorDescriptor; - requires ConvolutionDirectionWellDefinedIfProvided; - requires DataTypeWellDefinedIfProvided; + requires detail::ConvolutionDirectionWellDefinedIfProvided; + requires detail::DataTypeWellDefinedIfProvided; + requires detail::ElementwiseOpWellDefinedIfProvided; }; // Concept to validate a convolution signature's values.