mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
ck-builder: tensor input/output reflection (#3536)
This adds some utilities to automatically generate UniqueInputs, UniqueOutputs, alloc_inputs, alloc_outputs, and validate, based on a Inputs::reflect() and Outputs::reflect().
This commit is contained in:
@@ -4,6 +4,7 @@
|
||||
#include "impl/conv_signature_types.hpp"
|
||||
#include "testing_utils.hpp"
|
||||
#include "ck_tile/builder/testing/conv_fwd.hpp"
|
||||
#include "ck_tile/builder/testing/tensor_foreach.hpp"
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
#include <vector>
|
||||
@@ -12,6 +13,7 @@ namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
|
||||
using ::testing::ElementsAreArray;
|
||||
using ::testing::Eq;
|
||||
using ::testing::NotNull;
|
||||
|
||||
constexpr auto SIGNATURE =
|
||||
@@ -57,6 +59,8 @@ using UniqueOutputs = ckt::UniqueOutputs<SIGNATURE>;
|
||||
|
||||
static_assert(ckt::ValidUniqueInputs<SIGNATURE>);
|
||||
static_assert(ckt::ValidUniqueOutputs<SIGNATURE>);
|
||||
static_assert(ckt::TensorReflectable<Inputs, SIGNATURE>);
|
||||
static_assert(ckt::TensorReflectable<Outputs, SIGNATURE>);
|
||||
|
||||
TEST(ConvFwdTesting, MakeDescriptors)
|
||||
{
|
||||
@@ -81,3 +85,41 @@ TEST(ConvFwdTesting, Alloc)
|
||||
EXPECT_THAT(inputs.get().weight, NotNull());
|
||||
EXPECT_THAT(outputs.get().output, NotNull());
|
||||
}
|
||||
|
||||
TEST(ConvFwdTesting, Validate)
|
||||
{
|
||||
auto a = alloc_outputs(ARGS);
|
||||
auto b = alloc_outputs(ARGS);
|
||||
|
||||
// Positive test
|
||||
{
|
||||
ckt::Outputs<SIGNATURE>::reflect(
|
||||
ARGS,
|
||||
[&]([[maybe_unused]] std::string_view name,
|
||||
const auto& desc,
|
||||
void* ckt::Outputs<SIGNATURE>::*ptr) {
|
||||
ckt::clear_tensor_buffer(desc, a.get().*ptr, ck::bhalf_t{123});
|
||||
ckt::clear_tensor_buffer(desc, b.get().*ptr, ck::bhalf_t{123});
|
||||
});
|
||||
|
||||
const auto report = ckt::validate(ARGS, a.get(), b.get());
|
||||
EXPECT_THAT(report.get_errors().size(), Eq(0));
|
||||
}
|
||||
|
||||
// Negative test
|
||||
{
|
||||
size_t field_count = 0;
|
||||
ckt::Outputs<SIGNATURE>::reflect(
|
||||
ARGS,
|
||||
[&]([[maybe_unused]] std::string_view name,
|
||||
const auto& desc,
|
||||
void* ckt::Outputs<SIGNATURE>::*ptr) {
|
||||
++field_count;
|
||||
ckt::clear_tensor_buffer(desc, a.get().*ptr, ck::bhalf_t{2});
|
||||
ckt::clear_tensor_buffer(desc, b.get().*ptr, ck::bhalf_t{1});
|
||||
});
|
||||
|
||||
const auto report = ckt::validate(ARGS, a.get(), b.get());
|
||||
EXPECT_THAT(report.get_errors().size(), Eq(field_count));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user