ck-builder: tensor input/output reflection (#3536)

This adds some utilities to automatically generate UniqueInputs,
UniqueOutputs, alloc_inputs, alloc_outputs, and validate, based
on a Inputs::reflect() and Outputs::reflect().
This commit is contained in:
Robin Voetter
2026-01-12 09:45:53 +01:00
committed by GitHub
parent 32408c8bc0
commit b352a68606
8 changed files with 299 additions and 102 deletions

View File

@@ -4,6 +4,7 @@
#include "impl/conv_signature_types.hpp"
#include "testing_utils.hpp"
#include "ck_tile/builder/testing/conv_fwd.hpp"
#include "ck_tile/builder/testing/tensor_foreach.hpp"
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <vector>
@@ -12,6 +13,7 @@ namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
using ::testing::ElementsAreArray;
using ::testing::Eq;
using ::testing::NotNull;
constexpr auto SIGNATURE =
@@ -57,6 +59,8 @@ using UniqueOutputs = ckt::UniqueOutputs<SIGNATURE>;
static_assert(ckt::ValidUniqueInputs<SIGNATURE>);
static_assert(ckt::ValidUniqueOutputs<SIGNATURE>);
static_assert(ckt::TensorReflectable<Inputs, SIGNATURE>);
static_assert(ckt::TensorReflectable<Outputs, SIGNATURE>);
TEST(ConvFwdTesting, MakeDescriptors)
{
@@ -81,3 +85,41 @@ TEST(ConvFwdTesting, Alloc)
EXPECT_THAT(inputs.get().weight, NotNull());
EXPECT_THAT(outputs.get().output, NotNull());
}
TEST(ConvFwdTesting, Validate)
{
auto a = alloc_outputs(ARGS);
auto b = alloc_outputs(ARGS);
// Positive test
{
ckt::Outputs<SIGNATURE>::reflect(
ARGS,
[&]([[maybe_unused]] std::string_view name,
const auto& desc,
void* ckt::Outputs<SIGNATURE>::*ptr) {
ckt::clear_tensor_buffer(desc, a.get().*ptr, ck::bhalf_t{123});
ckt::clear_tensor_buffer(desc, b.get().*ptr, ck::bhalf_t{123});
});
const auto report = ckt::validate(ARGS, a.get(), b.get());
EXPECT_THAT(report.get_errors().size(), Eq(0));
}
// Negative test
{
size_t field_count = 0;
ckt::Outputs<SIGNATURE>::reflect(
ARGS,
[&]([[maybe_unused]] std::string_view name,
const auto& desc,
void* ckt::Outputs<SIGNATURE>::*ptr) {
++field_count;
ckt::clear_tensor_buffer(desc, a.get().*ptr, ck::bhalf_t{2});
ckt::clear_tensor_buffer(desc, b.get().*ptr, ck::bhalf_t{1});
});
const auto report = ckt::validate(ARGS, a.get(), b.get());
EXPECT_THAT(report.get_errors().size(), Eq(field_count));
}
}