ck-builder: tensor input/output reflection (#3536)

This adds some utilities to automatically generate UniqueInputs,
UniqueOutputs, alloc_inputs, alloc_outputs, and validate, based
on a Inputs::reflect() and Outputs::reflect().
This commit is contained in:
Robin Voetter
2026-01-12 09:45:53 +01:00
committed by GitHub
parent 32408c8bc0
commit b352a68606
8 changed files with 299 additions and 102 deletions

View File

@@ -5,6 +5,8 @@
#include <concepts>
#include "ck_tile/builder/testing/tensor_descriptor.hpp"
#include "ck_tile/builder/testing/tensor_buffer.hpp"
#include "ck_tile/builder/testing/validation.hpp"
/// This file is the main header for the CK-Builder testing system. A high-level
@@ -132,8 +134,8 @@ struct Outputs;
/// be created using `alloc_inputs()` and that an instance of the corresponding
/// `Inputs` structure can be obtained using `.get()`.
///
/// @note The easiest way to implement this type is to use the `DeviceBuffer`
/// type to allocate individual device buffers for each input tensor.
/// @note A default implementation is provided for this type if `Inputs`
/// supports `TensorReflectable`.
///
/// @tparam SIGNATURE The signature to specialize the structure for.
///
@@ -151,8 +153,8 @@ struct UniqueInputs;
/// be created using `alloc_outputs()` and that an instance of the corresponding
/// `Outputs` structure can be obtained using `.get()`.
///
/// @note The easiest way to implement this type is to use the `DeviceBuffer`
/// type to allocate individual device buffers for each output tensor.
/// @note A default implementation is provided for this type if `Outputs`
/// supports `TensorReflectable`.
///
/// @tparam SIGNATURE The signature to specialize the structure for.
///
@@ -197,6 +199,12 @@ concept ValidUniqueOutputs = requires(UniqueOutputs<SIGNATURE>& inputs) {
/// amount of memory required and then allocate it on the device, for example
/// using `alloc_buffer` or `alloc_tensor_buffer`.
///
/// @note This function is explicitly deleted to generate compile errors
/// for missing implementations.
///
/// @note A default implementation is provided for this function if `Inputs`
/// supports `TensorReflectable`.
///
/// @tparam SIGNATURE The signature to specialize the structure for.
///
/// @param args The run-time arguments of the operation.
@@ -207,22 +215,22 @@ concept ValidUniqueOutputs = requires(UniqueOutputs<SIGNATURE>& inputs) {
/// @see alloc_tensor_buffer()
template <auto SIGNATURE>
requires ValidUniqueInputs<SIGNATURE>
UniqueInputs<SIGNATURE> alloc_inputs(const Args<SIGNATURE>& args);
UniqueInputs<SIGNATURE> alloc_inputs(const Args<SIGNATURE>& args) = delete;
/// @brief Allocate inputs corresponding to a signature.
/// @brief Initialize inputs corresponding to a signature.
///
/// The `init_inputs()` function is used to initialize pseudo-random data
/// to the tensors specified in the Inputs structure. Implementors should
/// fill each of the tensors in `inputs` with appropriate random data.
///
/// @note This function is explicitly deleted to generate compile errors
/// for missing implementations.
///
/// @tparam SIGNATURE the signature to specialize the structure for.
///
/// @param args The run-time arguments of the operation.
/// @param inputs The operation inputs to initialize with random data.
///
/// @note This function is explicitly deleted to generate compile errors
/// for missing implementations.
///
/// @see Inputs
/// @see tensor_initialization
template <auto SIGNATURE>
@@ -235,13 +243,16 @@ void init_inputs(const Args<SIGNATURE>& args, Inputs<SIGNATURE> inputs) = delete
/// amount of memory required and then allocate it on the device, for example
/// using `alloc_buffer` or `alloc_tensor_buffer`.
///
/// @note This function is explicitly deleted to generate compile errors
/// for missing implementations.
///
/// @note A default implementation is provided for this function if `Outputs`
/// supports `TensorReflectable`.
///
/// @tparam SIGNATURE The signature to specialize the structure for.
///
/// @param args The run-time arguments of the operation.
///
/// @note This function is explicitly deleted to generate compile errors
/// for missing implementations.
///
/// @see Outputs
/// @see UniqueOutputs
/// @see alloc_buffer()
@@ -262,15 +273,15 @@ UniqueInputs<SIGNATURE> alloc_outputs(const Args<SIGNATURE>& args) = delete;
/// were incorrect, and where (a subset of) those elements are located within
/// the tensor. See `ValidationReport` for more information about the report.
///
/// @note This function is explicitly deleted to generate compile errors
/// for missing implementations.
///
/// @tparam SIGNATURE The signature to specialize the structure for.
///
/// @param args The run-time arguments of the operation.
/// @param actual The actual results, the results of the operation to-be-tested.
/// @param expected The expected results, the results of the reference implementation.
///
/// @note This function is explicitly deleted to generate compile errors
/// for missing implementations.
///
/// @see ValidationReport
template <auto SIGNATURE>
ValidationReport validate(const Args<SIGNATURE>& args,