[CK_BUILDER] Integrate reference conv with testing (#3511)

* ck-builder: explicitly delete forward declarations

Before, these functions were seen as a forward declaration for an existing function.
If no actual implementation overload could be found, these would be selected and
a linker error or warning would be generated. By marking these functions as explicitly
deleted, they incorrect invocations are generated as compile error instead.

* ck-builder: ckt::run plumbing for reference conv

This implements the ckt::run plumbing for the reference convolution
implementation and sets up the first complete end-to-end test.

* ck-builder: make validation system check for all-zeros

When both the actual and reference output are both all zero bits,
there is probably something wrong in the test framework.

* ck-builder: proper implementation+tests for TensorDescriptor::is_packed

* ck-builder: fix typos
This commit is contained in:
Robin Voetter
2026-01-06 09:29:06 +01:00
committed by GitHub
parent b78563b3d3
commit 1c433c64ec
9 changed files with 349 additions and 60 deletions

View File

@@ -125,9 +125,9 @@ struct ReferenceFactory
// Direct Run method (simpler interface, direction-agnostic)
template <typename InPtrType, typename WeiPtrType, typename OutPtrType>
static void Run(InPtrType input,
WeiPtrType weight,
OutPtrType output,
static void Run(InPtrType* input,
WeiPtrType* weight,
OutPtrType* output,
int G,
int N,
int K,
@@ -142,9 +142,9 @@ struct ReferenceFactory
if constexpr(ConvDirectionIsForward<SIGNATURE>)
{
ck_tile::naive_grouped_conv_fwd<SPATIAL_DIM, InDataType, WeiDataType, OutDataType>(
input,
weight,
output,
static_cast<const InDataType*>(input),
static_cast<const WeiDataType*>(weight),
static_cast<OutDataType*>(output),
G,
N,
K,
@@ -160,9 +160,9 @@ struct ReferenceFactory
{
ck_tile::
naive_grouped_conv_bwd_data<SPATIAL_DIM, InDataType, WeiDataType, OutDataType>(
input,
weight,
output,
static_cast<InDataType*>(input),
static_cast<const WeiDataType*>(weight),
static_cast<const OutDataType*>(output),
G,
N,
K,
@@ -179,19 +179,20 @@ struct ReferenceFactory
ck_tile::naive_grouped_conv_bwd_weight<SPATIAL_DIM,
InDataType,
WeiDataType,
OutDataType>(input,
weight,
output,
G,
N,
K,
C,
input_spatial,
filter_spatial,
output_spatial,
strides,
dilations,
left_pads);
OutDataType>(
static_cast<const InDataType*>(input),
static_cast<WeiDataType*>(weight),
static_cast<const OutDataType*>(output),
G,
N,
K,
C,
input_spatial,
filter_spatial,
output_spatial,
strides,
dilations,
left_pads);
}
}

View File

@@ -3,10 +3,10 @@
#pragma once
#include <span>
#include <cstddef>
#include "ck_tile/builder/testing/conv_fwd.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include <type_traits>
#include <array>
/// This file contains the implementation details for invoking/testing
/// grouped convolution operations in old CK. The main item is the
@@ -15,6 +15,63 @@
namespace ck_tile::builder::test {
namespace detail {
/// @brief Concept for checking whether this is the reference convolution
/// implementation.
///
/// This is the same as `::ck_tile::builder::test::CkConvInstance`, except
/// with some utility aliases. For that reason, its moved to this detail
/// namespace.
template <typename Conv,
auto SIGNATURE,
size_t SPATIAL_DIM = SIGNATURE.spatial_dim,
// TODO: We shouldn't need to call into an internal namespace here.
typename Ops = factory::internal::ElementwiseOps<SIGNATURE>>
concept CkConvInstance = requires(Conv& conv,
// TODO: This should be changed depending on IsMultiA etc.
// Currently that is not yet supported elsewhere anyway.
const void* p_a,
const void* p_b,
void* p_e,
std::array<index_t, SPATIAL_DIM + 3> lengths,
std::array<index_t, SPATIAL_DIM + 3> strides,
std::array<index_t, SPATIAL_DIM> filter,
Ops::AElementwiseOp elementwise_a,
Ops::BElementwiseOp elementwise_b,
Ops::CDEElementwiseOp elementwise_cde) {
{
conv.MakeArgument(p_a,
p_b,
// TODO: Support multiple D outputs.
{},
p_e,
// A lengths/strides
lengths,
strides,
// B lengths/strides
lengths,
strides,
// TODO: Ds lengths/strides
{},
{},
// E lengths/strides
lengths,
strides,
// strides/dilations/pads
filter,
filter,
filter,
filter,
// element-wise operations.
elementwise_a,
elementwise_b,
elementwise_cde)
};
};
} // namespace detail
/// @brief Concept for checking whether a convolution is invoked like old CK.
///
/// This concept is used to tell whether a convolution implementation is
@@ -24,13 +81,8 @@ namespace ck_tile::builder::test {
///
/// - SIGNATURE is the operation signature.
/// - Conv is a convolution instance created by the CK Builder API.
template <auto SIGNATURE, typename Conv>
concept IsCkConvInstance =
// TODO: This should be implemented by converting the signature into the
// type parameters for DeviceGroupedConvFwdMultipleABD. For now, just leave
// it empty. Improve when needed, you get the point. Also we should probably
// move this to the ck conv factory helper.
true;
template <typename Conv, auto SIGNATURE>
concept CkConvInstance = detail::CkConvInstance<Conv, SIGNATURE>;
/// @brief `run()` specialization for forward convolution and old CK.
///
@@ -39,10 +91,9 @@ concept IsCkConvInstance =
/// operation. This should be caught and reported by the testing framework.
///
/// @see run()
template <auto SIGNATURE, typename Conv>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE> &&
IsCkConvInstance<SIGNATURE, Conv>
void run(Conv& conv,
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
void run(CkConvInstance<SIGNATURE> auto& conv,
const Args<SIGNATURE>& args,
const Inputs<SIGNATURE>& inputs,
const Outputs<SIGNATURE>& outputs)

View File

@@ -0,0 +1,114 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/builder/testing/conv_fwd.hpp"
#include <stdexcept>
#include <vector>
/// This file contains the implementation details for invoking/testing
/// grouped convolution operations using the reference implementation.
/// The main item is the `run()` function, which is the primary way to
/// invoke the reference execution mechanism.
/// The implementation of this file mostly looks like `conv_fwd_ck.hpp`,
/// but its made specific to the reference implementation, which is
/// invoked in a slightly different way.
namespace ck_tile::builder::test {
/// @brief Concept for checking whether this is the reference convolution
/// implementation.
///
/// This concept is used to tell whether a convolution implementation is
/// likely to be the reference implementation - that is, whether we should
/// invoke it like the reference kernel. This is mainly used with `run()` to
/// differentiate which implementation that should be invoked.
///
/// - SIGNATURE is the operation signature.
/// - Conv is a convolution instance created by the CK Builder API.
template <typename Conv, auto SIGNATURE>
concept RefConvInstance = requires(Conv& conv,
const void* input,
const void* weight,
void* output,
int G,
int N,
int K,
int C,
std::vector<long_index_t> dims) {
{
conv.Run(input,
weight,
output,
G,
N,
K,
C,
dims, // input_spatial
dims, // filter_spatial
dims, // output_spatial
dims, // strides
dims, // dilations
dims // left_pads
)
};
};
/// @brief `run()` specialization for forward convolution and the reference
/// implementation.
///
/// @tparam SIGNATURE Forward convolution signature.
/// @throws std::runtime_error if the arguments weren't actually valid for the
/// operation. This should be caught and reported by the testing framework.
///
/// @see run()
template <auto SIGNATURE>
requires ValidConvSignature<SIGNATURE> &&
// TODO: Maybe we can unify this implementation for bwd/weight too?
// for now, just concern outselves with reference and see when the
// rest of the bwd/weight plumbing is there.
ConvDirectionIsForward<SIGNATURE>
void run(RefConvInstance<SIGNATURE> auto& conv,
const Args<SIGNATURE>& args,
const Inputs<SIGNATURE>& inputs,
const Outputs<SIGNATURE>& outputs)
{
// We don't want to compute the output dims manually, just get
// them via the existing infrastructure
const auto param = args.to_ck_conv_param();
// TODO: The reference convolution is currently missing a few features.
// Just throw for now, but regard these as TODO items that should be resolved
// eventually.
// Right pads are not supported right now for some reason.
for(auto right_pad : param.input_right_pads_)
{
if(right_pad != 0)
throw std::runtime_error("TODO: Support right pad in reference conv");
}
if(!args.make_input_descriptor().is_packed())
throw std::runtime_error("TODO: Support non-packed input tensor in reference conv");
if(!args.make_weight_descriptor().is_packed())
throw std::runtime_error("TODO: Support non-packed weight tensor in reference conv");
if(!args.make_output_descriptor().is_packed())
throw std::runtime_error("TODO: Support non-packed output tensor in reference conv");
conv.Run(inputs.input,
inputs.weight,
outputs.output,
param.G_,
param.N_,
param.K_,
param.C_,
param.input_spatial_lengths_,
param.filter_spatial_lengths_,
param.output_spatial_lengths_,
param.conv_filter_strides_,
param.conv_filter_dilations_,
param.input_left_pads_);
}
} // namespace ck_tile::builder::test

View File

@@ -8,6 +8,7 @@
#include <vector>
#include <sstream>
#include <concepts>
#include <algorithm>
#include <hip/hip_runtime.h>
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/testing/type_traits.hpp"
@@ -369,6 +370,35 @@ struct TensorDescriptor
return get_element_space_size() * data_type_sizeof(DT);
}
/// @brief Check if a tensor is packed in memory.
///
/// This function checks whether the tensor memory is "packed", that is, whether
/// all elements are continuous in memory with no gaps.
bool is_packed() const
{
// First sort by stride, then check if they match the scan of the
// sizes.
const auto& lengths = inner_descriptor_.get_lengths();
const auto& strides = inner_descriptor_.get_strides();
std::array<size_t, RANK> indices;
std::iota(indices.begin(), indices.end(), 0);
std::sort(indices.begin(), indices.end(), [&](auto i, auto j) {
return strides[i] < strides[j];
});
size_t x = 1;
for(size_t i = 0; i < RANK; ++i)
{
if(strides[indices[i]] != x)
return false;
x *= lengths[indices[i]];
}
return true;
}
/// @brief Get a tensor descriptor for the space backing a tensor.
///
/// This function returns a tensor descriptor which represents the buffer space

View File

@@ -220,10 +220,13 @@ UniqueInputs<SIGNATURE> alloc_inputs(const Args<SIGNATURE>& args);
/// @param args The run-time arguments of the operation.
/// @param inputs The operation inputs to initialize with random data.
///
/// @note This function is explicitly deleted to generate compile errors
/// for missing implementations.
///
/// @see Inputs
/// @see tensor_initialization
template <auto SIGNATURE>
void init_inputs(const Args<SIGNATURE>& args, Inputs<SIGNATURE> inputs);
void init_inputs(const Args<SIGNATURE>& args, Inputs<SIGNATURE> inputs) = delete;
/// @brief Allocate outputs corresponding to a signature.
///
@@ -236,13 +239,16 @@ void init_inputs(const Args<SIGNATURE>& args, Inputs<SIGNATURE> inputs);
///
/// @param args The run-time arguments of the operation.
///
/// @note This function is explicitly deleted to generate compile errors
/// for missing implementations.
///
/// @see Outputs
/// @see UniqueOutputs
/// @see alloc_buffer()
/// @see alloc_tensor_buffer()
template <auto SIGNATURE>
requires ValidUniqueOutputs<SIGNATURE>
UniqueInputs<SIGNATURE> alloc_outputs(const Args<SIGNATURE>& args);
UniqueInputs<SIGNATURE> alloc_outputs(const Args<SIGNATURE>& args) = delete;
/// @brief Compare device operation outputs.
///
@@ -262,10 +268,14 @@ UniqueInputs<SIGNATURE> alloc_outputs(const Args<SIGNATURE>& args);
/// @param actual The actual results, the results of the operation to-be-tested.
/// @param expected The expected results, the results of the reference implementation.
///
/// @note This function is explicitly deleted to generate compile errors
/// for missing implementations.
///
/// @see ValidationReport
template <auto SIGNATURE>
ValidationReport
validate(const Args<SIGNATURE>& args, Outputs<SIGNATURE> actual, Outputs<SIGNATURE> expected);
ValidationReport validate(const Args<SIGNATURE>& args,
Outputs<SIGNATURE> actual,
Outputs<SIGNATURE> expected) = delete;
/// @brief Invoke a device operation created by CK Builder.
///
@@ -296,10 +306,13 @@ validate(const Args<SIGNATURE>& args, Outputs<SIGNATURE> actual, Outputs<SIGNATU
/// @param inputs The input tensor data. Will not be modified by this function.
/// @param outputs The output tensor data. The contents will be overwritten by
/// this function.
///
/// @note This function is explicitly deleted to generate compile errors
/// for missing implementations.
template <auto SIGNATURE, typename Operation>
void run(Operation& operation,
const Args<SIGNATURE>& args,
const Inputs<SIGNATURE>& inputs,
const Outputs<SIGNATURE>& outputs);
const Outputs<SIGNATURE>& outputs) = delete;
} // namespace ck_tile::builder::test

View File

@@ -13,6 +13,7 @@
#include <vector>
#include <algorithm>
#include <functional>
#include <bit>
/// This file implements functionality related to "validation", ie, functionality
/// to compare tensors. The functionality in this file should be testing-framework
@@ -48,12 +49,22 @@ struct ValidationReport
/// The total number of elements in each tensor.
uint64_t total_elements;
/// The number of elements which were bitwise 0.
uint64_t zero_elements;
/// @brief Check whether both the output and reference tensor were both all zeros.
///
/// If both tensors are all zero, it indicates either an incorrect testing setup
/// or an issue with the testing framework. For that reason we also consider that
/// a failure.
bool is_all_zero() const { return zero_elements == total_elements; }
/// @brief Return whether the check associated to this case was successful.
///
/// This function returns whether the check associated to this case was successful,
/// which is directly derived from checking whether the number of incorrect elements
/// was 0.
bool is_ok() const { return wrong_elements == 0; }
/// was 0 AND whether the tensor was not all zero.
bool is_ok() const { return wrong_elements == 0 && !is_all_zero(); }
};
/// @brief Get comparison cases which were incorrect.
@@ -123,10 +134,13 @@ bool ValidationReport::check(std::string_view tensor_name,
// Initial pass: count errors
// Allocate and reset counter
auto d_error_count = alloc_buffer(sizeof(uint64_t));
check_hip(hipMemset(d_error_count.get(), 0, sizeof(uint64_t)));
auto d_counters = alloc_buffer(sizeof(uint64_t) * 2);
check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 2));
tensor_foreach(descriptor.get_lengths(), [=, error_count = d_error_count.get()](auto index) {
auto d_error_count = &reinterpret_cast<uint64_t*>(d_counters.get())[0];
auto d_zero_count = &reinterpret_cast<uint64_t*>(d_counters.get())[1];
tensor_foreach(descriptor.get_lengths(), [=](auto index) {
using CKType = typename factory::internal::DataTypeToCK<DT>::type;
const auto* actual = static_cast<const CKType*>(actual_data);
@@ -137,21 +151,44 @@ bool ValidationReport::check(std::string_view tensor_name,
const auto offset = calculate_offset(index, strides);
const auto o = static_cast<double>(type_convert<float>(actual[offset]));
const auto r = static_cast<double>(type_convert<float>(expected[offset]));
const auto a = actual[offset];
const auto b = expected[offset];
const auto o = static_cast<double>(type_convert<float>(a));
const auto r = static_cast<double>(type_convert<float>(b));
const auto err = std::abs(o - r);
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
{
// We expect the number of errors to be very low, so just use an atomic
// for now.
atomicAdd(reinterpret_cast<uint64_t*>(error_count), 1);
atomicAdd(d_error_count, 1);
}
// Now compare the numbers as bitwise too.
// Update the counter if they're both zero.
using Bytes = std::array<std::byte, sizeof(CKType)>;
bool all_zero = true;
for(auto x : std::bit_cast<Bytes>(a))
{
if(x != std::byte{0})
all_zero = false;
}
for(auto x : std::bit_cast<Bytes>(b))
{
if(x != std::byte{0})
all_zero = false;
}
if(all_zero)
{
atomicAdd(d_zero_count, 1);
}
});
uint64_t error_count = 0;
check_hip(
hipMemcpy(&error_count, d_error_count.get(), sizeof(uint64_t), hipMemcpyDeviceToHost));
check_hip(hipMemcpy(&error_count, d_error_count, sizeof(uint64_t), hipMemcpyDeviceToHost));
uint64_t zero_count = 0;
check_hip(hipMemcpy(&zero_count, d_zero_count, sizeof(uint64_t), hipMemcpyDeviceToHost));
// TODO: Gather detailed coordinates.
@@ -159,9 +196,10 @@ bool ValidationReport::check(std::string_view tensor_name,
.tensor_name = std::string(tensor_name),
.wrong_elements = error_count,
.total_elements = descriptor.get_element_size(),
.zero_elements = zero_count,
});
return error_count == 0;
return reports_.back().is_ok();
}
} // namespace ck_tile::builder::test

View File

@@ -5,6 +5,7 @@
#include "utils/ckb_conv_test_utils.hpp"
#include "utils/conv_algorithm_type_utils.hpp"
#include "ck_tile/builder/testing/conv_fwd_ck.hpp"
#include "ck_tile/builder/testing/conv_fwd_reference.hpp"
#include "ck_tile/host/device_prop.hpp"
#include "testing_utils.hpp"
@@ -34,6 +35,8 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xd
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;
using Reference = ckb::ConvBuilder<SIGNATURE, ckt::ConvAlgorithm_Reference{}>::Instance;
TEST(Fwd2DFp16_CShufV3_GNHWC, Create)
{
const auto expected_transfer_parameters = to_string(ALGORITHM);
@@ -81,18 +84,17 @@ TEST(Fwd2DFp16_CShufV3_GNHWC, EndToEnd)
.cde_elementwise_op = {},
};
auto inputs = ckt::alloc_inputs(args);
auto outputs = ckt::alloc_outputs(args);
auto inputs = ckt::alloc_inputs(args);
auto outputs = ckt::alloc_outputs(args);
auto reference = ckt::alloc_outputs(args);
ckt::init_inputs(args, inputs.get());
auto conv = Instance{};
ckt::run(conv, args, inputs.get(), outputs.get());
// TODO: This should be allocated via ckt::alloc_outputs() and
// initialized via ckt::run() with the reference implementation
// instead.
auto reference = outputs.get();
auto ref_conv = Reference{};
ckt::run(ref_conv, args, inputs.get(), reference.get());
EXPECT_THAT(outputs.get(), MatchesReference(args, reference));
EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get()));
}

View File

@@ -170,3 +170,22 @@ TEST(TensorDescriptor, ExtentFromVector)
EXPECT_THAT([] { return ckt::Extent<5>::from_vector(std::vector<size_t>{1, 2}); },
Throws<std::runtime_error>());
}
TEST(TensorDescriptor, IsPacked)
{
constexpr auto dt = ckb::DataType::INT32; // Irrelevant for this test
EXPECT_TRUE(
ckt::make_descriptor<dt>(ckt::Extent{101, 43, 25, 662, 654}, ckt::PackedLeftLayout{})
.is_packed());
EXPECT_TRUE(
ckt::make_descriptor<dt>(ckt::Extent{5334, 235, 1563, 256, 23}, ckt::PackedRightLayout{})
.is_packed());
EXPECT_TRUE(ckt::make_descriptor<dt>(ckt::Extent{}, ckt::Extent{}).is_packed());
EXPECT_TRUE(
ckt::make_descriptor<dt>(ckt::Extent{461, 345, 5, 93}, ckt::Extent{160425, 5, 1, 1725})
.is_packed());
EXPECT_FALSE(
ckt::make_descriptor<dt>(ckt::Extent{10, 11, 12}, ckt::Extent{1, 100, 1100}).is_packed());
EXPECT_FALSE(
ckt::make_descriptor<dt>(ckt::Extent{30, 20, 10}, ckt::Extent{1, 1, 1}).is_packed());
}

View File

@@ -67,7 +67,7 @@ TYPED_TEST(ValidationReportTests, SingleCorrect)
// Generate a sort-of-random looking sequence
auto generator = [strides = desc.get_strides()](const auto& index) {
const auto flat_index = ckt::calculate_offset(index, strides);
return static_cast<float>(flat_index * 10'000'019 % 768'351);
return static_cast<float>((flat_index + 1) * 10'000'019 % 768'351);
};
ckt::fill_tensor(desc, a.get(), generator);
@@ -110,6 +110,27 @@ TYPED_TEST(ValidationReportTests, SingleIncorrect)
EXPECT_THAT(errors[0].total_elements, Eq(desc.get_element_size()));
}
TYPED_TEST(ValidationReportTests, ZeroIsIncorrect)
{
const auto desc = TypeParam::get_descriptor();
auto a = ckt::alloc_tensor_buffer(desc);
auto b = ckt::alloc_tensor_buffer(desc);
ckt::clear_tensor_buffer(desc, a.get());
ckt::clear_tensor_buffer(desc, b.get());
ckt::ValidationReport report;
report.check("zero_is_incorrect", desc, b.get(), a.get());
const auto errors = report.get_errors();
ASSERT_THAT(errors.size(), Eq(1));
EXPECT_THAT(errors[0].tensor_name, StrEq("zero_is_incorrect"));
EXPECT_THAT(errors[0].wrong_elements, Eq(0));
EXPECT_THAT(errors[0].total_elements, Eq(desc.get_element_size()));
EXPECT_THAT(errors[0].zero_elements, Eq(desc.get_element_size()));
}
TEST(ValidationReportTests, MultipleSomeIncorrect)
{
ckt::ValidationReport report;