[CK_BUILDER] Integrate CKB validation with CK verification (#3649)

* ck-builder: tensor copy function

This function copies one tensor to another, so that the memory
layout can be changed between them.

* ck-builder: fix ck::bhalf literals

These types don't work properly.

* ck-builder: abstract compare_elements in gpu_verification.hpp and make builder use it

This reduces the amount of duplicated code a bit.

* ck-builder: add flat tensor iterator

This "iterator" type pretends to be a pointer, useful for passing
tensors to functions expecting pointer-like types.

* ck-builder: integrate validation with ck gpu verification

By templating the gpu_verify function over iterators, we can use
the new FlatTensorIterator to adapt the function to multi-
dimensional tensors without changing either implementation
too much.

* ck-builder: add check_by_accumulations

This changes the gpu_verification.hpp code to also accept "iterator"
types for the relevant gpu_verify and gpu_reduce_max functions.

* ck: fix test_gpu_verification GenerateRandomData for bhalf

is_integer_it<bhalf_t> yields true, but it is not actually
an integer.

* ck: make gpu_verification kernels be proper persistent kernels

Previously these were using a hardcoded value for the grid size. This
commit changes that so that the grid size is automatically derived
from the kernel's occupancy and the number of multiprocessors on
the GPU.

* ck: clean up gpu_verification.hpp using block_reduce

This implements a small generic block reduce function, and rewrites
the rest of gpu_verification.hpp using that function to clean it up
a bit.

* ck-builder: doc typos

* ck-builder: update testing readme with validation interface.

* ck-builder: rebase fixes + review comments

* ck-builder: fix device integer generation with float types

Passing bfloat here causes a nans due to type_convert performing
a bitcast.

* ck: another bhalf_t bug

CK expects that int-generation with ck::bhalf_t yields bhalf integers,
not unsigned integers. This makes the logic of FillUniformRandInteger
compatible with GeneratorTensor_2<InDataType>, however idiotic that
may be.
This commit is contained in:
Robin Voetter
2026-01-28 17:41:02 +01:00
committed by GitHub
parent d6cccf6093
commit 42048bdb7d
11 changed files with 636 additions and 291 deletions

View File

@@ -209,7 +209,8 @@ struct ReferenceOutputMatcher
// Round to 2 digits
const float percentage = e.wrong_elements * 10000 / e.total_elements / 100.f;
*listener << e.wrong_elements << "/" << e.total_elements
<< " incorrect elements (~" << percentage << "%)";
<< " incorrect elements (~" << percentage << "%)," << " max error "
<< e.max_error;
}
}
}

View File

@@ -98,8 +98,10 @@ TEST(ConvFwdTesting, Validate)
[&]([[maybe_unused]] std::string_view name,
const auto& desc,
void* ckt::Outputs<SIGNATURE>::*ptr) {
ckt::clear_tensor_buffer(desc, a.get().*ptr, ck::bhalf_t{123});
ckt::clear_tensor_buffer(desc, b.get().*ptr, ck::bhalf_t{123});
ckt::clear_tensor_buffer(
desc, a.get().*ptr, ck::type_convert<ck::bhalf_t, float>(123));
ckt::clear_tensor_buffer(
desc, b.get().*ptr, ck::type_convert<ck::bhalf_t, float>(123));
});
const auto report = ckt::validate(ARGS, a.get(), b.get());
@@ -115,8 +117,10 @@ TEST(ConvFwdTesting, Validate)
const auto& desc,
void* ckt::Outputs<SIGNATURE>::*ptr) {
++field_count;
ckt::clear_tensor_buffer(desc, a.get().*ptr, ck::bhalf_t{2});
ckt::clear_tensor_buffer(desc, b.get().*ptr, ck::bhalf_t{1});
ckt::clear_tensor_buffer(
desc, a.get().*ptr, ck::type_convert<ck::bhalf_t, float>(2));
ckt::clear_tensor_buffer(
desc, b.get().*ptr, ck::type_convert<ck::bhalf_t, float>(1));
});
const auto report = ckt::validate(ARGS, a.get(), b.get());

View File

@@ -225,3 +225,99 @@ TEST(TensorForeach, ClearTensorZeros)
EXPECT_THAT(actual, Eq(0));
}
TEST(TensorForeach, CopyTensor)
{
constexpr auto dt = ckb::DataType::I32;
const ckt::Extent shape = {10, 3, 45, 23, 6};
using Counter = uint32_t;
const auto src_desc = ckt::make_descriptor<dt>(shape, ckt::PackedRightLayout{});
const auto dst_desc = ckt::make_descriptor<dt>(shape, ckt::PackedLeftLayout{});
auto src_buffer = ckt::alloc_tensor_buffer(src_desc);
auto dst_buffer = ckt::alloc_tensor_buffer(dst_desc);
const auto gen = [](const auto& index, const auto& lengths) {
// Simple incrementing counter
return static_cast<Counter>(ckt::calculate_offset(index, lengths));
};
ckt::fill_tensor(
src_desc, src_buffer.get(), [lengths = src_desc.get_lengths(), gen](const auto& index) {
return gen(index, lengths);
});
ckt::clear_tensor_buffer(dst_desc, dst_buffer.get());
// Perform the actual test
ckt::copy_tensor(src_desc, src_buffer.get(), dst_desc, dst_buffer.get());
// Check that the dst tensor has the same data
auto d_invalid = ckt::alloc_buffer(sizeof(Counter));
ckt::check_hip(hipMemset(d_invalid.get(), 0, sizeof(Counter)));
ckt::tensor_foreach(shape,
[lengths = dst_desc.get_lengths(),
gen,
dst = dst_buffer.get(),
invalid = reinterpret_cast<Counter*>(d_invalid.get()),
strides = dst_desc.get_strides()](const auto& index) {
const auto offset = ckt::calculate_offset(index, strides);
const auto expected = gen(index, lengths);
const auto actual = reinterpret_cast<const Counter*>(dst)[offset];
if(expected != actual)
atomicAdd(invalid, 1);
});
Counter invalid = 0;
ckt::check_hip(hipMemcpy(&invalid, d_invalid.get(), sizeof(Counter), hipMemcpyDeviceToHost));
EXPECT_THAT(invalid, Eq(0));
}
TEST(TensorForeach, FlatTensorIterator)
{
using Counter = uint32_t;
constexpr auto dt = ckb::DataType::I32;
const ckt::Extent shape = {10, 9, 8, 7, 6, 5, 4, 3, 2, 1};
const ckt::Extent packed_strides = ckt::PackedRightLayout{}(shape);
const auto desc = ckt::make_descriptor<dt>(shape, ckt::PackedLeftLayout{});
auto buffer = ckt::alloc_tensor_buffer(desc);
// Fill the tensor with random values according to the *flat* index. The
// FlatTensorIterator iterates over flat values even if the strides are not
// packed, so indexing these elements according to the flat index in the
// iterator should yield again this value.
ckt::fill_tensor(desc, buffer.get(), [packed_strides](const auto& index) {
const auto flat_index = ckt::calculate_offset(index, packed_strides);
return static_cast<int32_t>(flat_index * 10001 % 1001);
});
auto iterator = ckt::FlatTensorIterator(desc, reinterpret_cast<const int32_t*>(buffer.get()));
auto d_invalid = ckt::alloc_buffer(sizeof(Counter));
ckt::check_hip(hipMemset(d_invalid.get(), 0, sizeof(Counter)));
ckt::tensor_foreach(shape,
[iterator,
packed_strides,
strides = desc.get_strides(),
data = reinterpret_cast<const int32_t*>(buffer.get()),
invalid = reinterpret_cast<Counter*>(d_invalid.get())](const auto& index) {
const auto flat_index = ckt::calculate_offset(index, packed_strides);
const auto offset = ckt::calculate_offset(index, strides);
if(iterator[flat_index] != data[offset])
atomicAdd(invalid, 1);
});
Counter invalid = 0;
ckt::check_hip(hipMemcpy(&invalid, d_invalid.get(), sizeof(Counter), hipMemcpyDeviceToHost));
EXPECT_THAT(invalid, Eq(0));
}

View File

@@ -74,7 +74,8 @@ TYPED_TEST(ValidationReportTests, SingleCorrect)
ckt::fill_tensor(desc, b.get(), generator);
ckt::ValidationReport report;
report.check("correct", desc, b.get(), a.get());
report.check("correct - explicit tolerance", desc, b.get(), a.get());
report.check_by_accumulations("correct - implicit tolerance", desc, b.get(), a.get(), 0);
EXPECT_THAT(report.get_errors().size(), Eq(0));
}
@@ -97,17 +98,22 @@ TYPED_TEST(ValidationReportTests, SingleIncorrect)
});
ckt::ValidationReport report;
report.check("incorrect", desc, b.get(), a.get());
report.check("incorrect - explicit tolerance", desc, b.get(), a.get());
report.check_by_accumulations("incorrect - implicit tolerance", desc, b.get(), a.get(), 0);
const auto errors = report.get_errors();
const auto flat_size = desc.get_element_size();
const auto expected_errors = flat_size >= 999999 ? 3 : flat_size >= 12345 ? 2 : 1;
ASSERT_THAT(errors.size(), Eq(1));
EXPECT_THAT(errors[0].tensor_name, StrEq("incorrect"));
EXPECT_THAT(errors[0].wrong_elements, Eq(expected_errors));
EXPECT_THAT(errors[0].total_elements, Eq(desc.get_element_size()));
ASSERT_THAT(errors.size(), Eq(2));
EXPECT_THAT(errors[0].tensor_name, StrEq("incorrect - explicit tolerance"));
EXPECT_THAT(errors[1].tensor_name, StrEq("incorrect - implicit tolerance"));
for(int i = 0; i < 2; ++i)
{
EXPECT_THAT(errors[i].wrong_elements, Eq(expected_errors));
EXPECT_THAT(errors[i].total_elements, Eq(desc.get_element_size()));
}
}
TYPED_TEST(ValidationReportTests, ZeroIsIncorrect)
@@ -121,14 +127,20 @@ TYPED_TEST(ValidationReportTests, ZeroIsIncorrect)
ckt::clear_tensor_buffer(desc, b.get());
ckt::ValidationReport report;
report.check("zero_is_incorrect", desc, b.get(), a.get());
report.check("zero_is_incorrect - explicit tolerance", desc, b.get(), a.get());
report.check_by_accumulations(
"zero_is_incorrect - implicit tolerance", desc, b.get(), a.get(), 0);
const auto errors = report.get_errors();
ASSERT_THAT(errors.size(), Eq(1));
EXPECT_THAT(errors[0].tensor_name, StrEq("zero_is_incorrect"));
EXPECT_THAT(errors[0].wrong_elements, Eq(0));
EXPECT_THAT(errors[0].total_elements, Eq(desc.get_element_size()));
EXPECT_THAT(errors[0].zero_elements, Eq(desc.get_element_size()));
ASSERT_THAT(errors.size(), Eq(2));
EXPECT_THAT(errors[0].tensor_name, StrEq("zero_is_incorrect - explicit tolerance"));
EXPECT_THAT(errors[1].tensor_name, StrEq("zero_is_incorrect - implicit tolerance"));
for(int i = 0; i < 2; ++i)
{
EXPECT_THAT(errors[i].wrong_elements, Eq(0));
EXPECT_THAT(errors[i].total_elements, Eq(desc.get_element_size()));
EXPECT_THAT(errors[i].both_all_zero, Eq(true));
}
}
TEST(ValidationReportTests, MultipleSomeIncorrect)
@@ -143,11 +155,12 @@ TEST(ValidationReportTests, MultipleSomeIncorrect)
auto b = ckt::alloc_tensor_buffer(desc);
ckt::fill_tensor_buffer(
desc, a.get(), [](size_t i) { return ck::type_convert<ck::bhalf_t>(i % 100); });
desc, a.get(), [](size_t i) { return ck::type_convert<ck::bhalf_t>(float(i % 100)); });
ckt::fill_tensor_buffer(
desc, b.get(), [](size_t i) { return ck::type_convert<ck::bhalf_t>(i % 101); });
desc, b.get(), [](size_t i) { return ck::type_convert<ck::bhalf_t>(float(i % 101)); });
report.check("incorrect 1", desc, b.get(), a.get());
report.check("incorrect 1 - explicit tolerance", desc, b.get(), a.get());
report.check("incorrect 1 - implicit tolerance", desc, b.get(), a.get(), 0);
}
{
@@ -169,7 +182,8 @@ TEST(ValidationReportTests, MultipleSomeIncorrect)
}
});
report.check("correct", desc, b.get(), a.get());
report.check("correct - explicit tolerance", desc, b.get(), a.get());
report.check("correct - implicit tolerance", desc, b.get(), a.get(), 0);
}
{
@@ -182,16 +196,21 @@ TEST(ValidationReportTests, MultipleSomeIncorrect)
ckt::fill_tensor_buffer(desc, a.get(), []([[maybe_unused]] size_t i) { return 1; });
ckt::fill_tensor_buffer(desc, b.get(), []([[maybe_unused]] size_t i) { return 555; });
report.check("incorrect 2", desc, b.get(), a.get());
report.check("incorrect 2 - explicit tolerance", desc, b.get(), a.get());
report.check("incorrect 2 - implicit tolerance", desc, b.get(), a.get(), 0);
}
const auto errors = report.get_errors();
ASSERT_THAT(errors.size(), Eq(2));
EXPECT_THAT(errors[0].tensor_name, StrEq("incorrect 1"));
ASSERT_THAT(errors.size(), Eq(4));
EXPECT_THAT(errors[0].tensor_name, StrEq("incorrect 1 - explicit tolerance"));
EXPECT_THAT(errors[0].wrong_elements, Eq(46840334));
EXPECT_THAT(errors[1].tensor_name, StrEq("incorrect 2"));
EXPECT_THAT(errors[1].wrong_elements, Eq(482800));
EXPECT_THAT(errors[1].tensor_name, StrEq("incorrect 1 - implicit tolerance"));
EXPECT_THAT(errors[1].wrong_elements, Eq(46840334));
EXPECT_THAT(errors[2].tensor_name, StrEq("incorrect 2 - explicit tolerance"));
EXPECT_THAT(errors[2].wrong_elements, Eq(482800));
EXPECT_THAT(errors[3].tensor_name, StrEq("incorrect 2 - implicit tolerance"));
EXPECT_THAT(errors[3].wrong_elements, Eq(482800));
}
// MatchesReference operates on the types defined in testing.hpp, so just
@@ -234,7 +253,7 @@ ValidationReport validate<DUMMY_SIGNATURE>(const Args<DUMMY_SIGNATURE>& args,
{
ValidationReport report;
report.check("a", args.make_a_descriptor(), actual.a, expected.a);
report.check("b", args.make_b_descriptor(), actual.b, expected.b);
report.check_by_accumulations("b", args.make_b_descriptor(), actual.b, expected.b, 0);
return report;
}
@@ -299,5 +318,5 @@ TEST(MatchesReference, Incorrect)
EXPECT_THAT(listener.str(),
StringEqWithDiff( //
"1 tensors failed to validate\n"
" - a: 625/625 incorrect elements (~100%)"));
" - a: 625/625 incorrect elements (~100%), max error 1"));
}