mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_BUILDER] Integrate reference conv with testing (#3511)
* ck-builder: explicitly delete forward declarations Before, these functions were seen as a forward declaration for an existing function. If no actual implementation overload could be found, these would be selected and a linker error or warning would be generated. By marking these functions as explicitly deleted, incorrect invocations are reported as compile errors instead. * ck-builder: ckt::run plumbing for reference conv This implements the ckt::run plumbing for the reference convolution implementation and sets up the first complete end-to-end test. * ck-builder: make validation system check for all-zeros When the actual and reference outputs are both all zero bits, there is probably something wrong in the test framework. * ck-builder: proper implementation+tests for TensorDescriptor::is_packed * ck-builder: fix typos
This commit is contained in:
@@ -125,9 +125,9 @@ struct ReferenceFactory
|
||||
|
||||
// Direct Run method (simpler interface, direction-agnostic)
|
||||
template <typename InPtrType, typename WeiPtrType, typename OutPtrType>
|
||||
static void Run(InPtrType input,
|
||||
WeiPtrType weight,
|
||||
OutPtrType output,
|
||||
static void Run(InPtrType* input,
|
||||
WeiPtrType* weight,
|
||||
OutPtrType* output,
|
||||
int G,
|
||||
int N,
|
||||
int K,
|
||||
@@ -142,9 +142,9 @@ struct ReferenceFactory
|
||||
if constexpr(ConvDirectionIsForward<SIGNATURE>)
|
||||
{
|
||||
ck_tile::naive_grouped_conv_fwd<SPATIAL_DIM, InDataType, WeiDataType, OutDataType>(
|
||||
input,
|
||||
weight,
|
||||
output,
|
||||
static_cast<const InDataType*>(input),
|
||||
static_cast<const WeiDataType*>(weight),
|
||||
static_cast<OutDataType*>(output),
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
@@ -160,9 +160,9 @@ struct ReferenceFactory
|
||||
{
|
||||
ck_tile::
|
||||
naive_grouped_conv_bwd_data<SPATIAL_DIM, InDataType, WeiDataType, OutDataType>(
|
||||
input,
|
||||
weight,
|
||||
output,
|
||||
static_cast<InDataType*>(input),
|
||||
static_cast<const WeiDataType*>(weight),
|
||||
static_cast<const OutDataType*>(output),
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
@@ -179,19 +179,20 @@ struct ReferenceFactory
|
||||
ck_tile::naive_grouped_conv_bwd_weight<SPATIAL_DIM,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType>(input,
|
||||
weight,
|
||||
output,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial,
|
||||
filter_spatial,
|
||||
output_spatial,
|
||||
strides,
|
||||
dilations,
|
||||
left_pads);
|
||||
OutDataType>(
|
||||
static_cast<const InDataType*>(input),
|
||||
static_cast<WeiDataType*>(weight),
|
||||
static_cast<const OutDataType*>(output),
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial,
|
||||
filter_spatial,
|
||||
output_spatial,
|
||||
strides,
|
||||
dilations,
|
||||
left_pads);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,10 +3,10 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <span>
|
||||
#include <cstddef>
|
||||
|
||||
#include "ck_tile/builder/testing/conv_fwd.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include <type_traits>
|
||||
#include <array>
|
||||
|
||||
/// This file contains the implementation details for invoking/testing
|
||||
/// grouped convolution operations in old CK. The main item is the
|
||||
@@ -15,6 +15,63 @@
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// @brief Concept for checking whether this is an old CK convolution
|
||||
/// implementation.
|
||||
///
|
||||
/// This is the same as `::ck_tile::builder::test::CkConvInstance`, except
|
||||
/// with some utility aliases. For that reason, it's moved to this detail
|
||||
/// namespace.
|
||||
template <typename Conv,
|
||||
auto SIGNATURE,
|
||||
size_t SPATIAL_DIM = SIGNATURE.spatial_dim,
|
||||
// TODO: We shouldn't need to call into an internal namespace here.
|
||||
typename Ops = factory::internal::ElementwiseOps<SIGNATURE>>
|
||||
concept CkConvInstance = requires(Conv& conv,
|
||||
// TODO: This should be changed depending on IsMultiA etc.
|
||||
// Currently that is not yet supported elsewhere anyway.
|
||||
const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_e,
|
||||
std::array<index_t, SPATIAL_DIM + 3> lengths,
|
||||
std::array<index_t, SPATIAL_DIM + 3> strides,
|
||||
std::array<index_t, SPATIAL_DIM> filter,
|
||||
Ops::AElementwiseOp elementwise_a,
|
||||
Ops::BElementwiseOp elementwise_b,
|
||||
Ops::CDEElementwiseOp elementwise_cde) {
|
||||
{
|
||||
conv.MakeArgument(p_a,
|
||||
p_b,
|
||||
// TODO: Support multiple D outputs.
|
||||
{},
|
||||
p_e,
|
||||
// A lengths/strides
|
||||
lengths,
|
||||
strides,
|
||||
// B lengths/strides
|
||||
lengths,
|
||||
strides,
|
||||
// TODO: Ds lengths/strides
|
||||
{},
|
||||
{},
|
||||
// E lengths/strides
|
||||
lengths,
|
||||
strides,
|
||||
// strides/dilations/pads
|
||||
filter,
|
||||
filter,
|
||||
filter,
|
||||
filter,
|
||||
// element-wise operations.
|
||||
elementwise_a,
|
||||
elementwise_b,
|
||||
elementwise_cde)
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/// @brief Concept for checking whether a convolution is invoked like old CK.
|
||||
///
|
||||
/// This concept is used to tell whether a convolution implementation is
|
||||
@@ -24,13 +81,8 @@ namespace ck_tile::builder::test {
|
||||
///
|
||||
/// - SIGNATURE is the operation signature.
|
||||
/// - Conv is a convolution instance created by the CK Builder API.
|
||||
template <auto SIGNATURE, typename Conv>
|
||||
concept IsCkConvInstance =
|
||||
// TODO: This should be implemented by converting the signature into the
|
||||
// type parameters for DeviceGroupedConvFwdMultipleABD. For now, just leave
|
||||
// it empty. Improve when needed, you get the point. Also we should probably
|
||||
// move this to the ck conv factory helper.
|
||||
true;
|
||||
template <typename Conv, auto SIGNATURE>
|
||||
concept CkConvInstance = detail::CkConvInstance<Conv, SIGNATURE>;
|
||||
|
||||
/// @brief `run()` specialization for forward convolution and old CK.
|
||||
///
|
||||
@@ -39,10 +91,9 @@ concept IsCkConvInstance =
|
||||
/// operation. This should be caught and reported by the testing framework.
|
||||
///
|
||||
/// @see run()
|
||||
template <auto SIGNATURE, typename Conv>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE> &&
|
||||
IsCkConvInstance<SIGNATURE, Conv>
|
||||
void run(Conv& conv,
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
|
||||
void run(CkConvInstance<SIGNATURE> auto& conv,
|
||||
const Args<SIGNATURE>& args,
|
||||
const Inputs<SIGNATURE>& inputs,
|
||||
const Outputs<SIGNATURE>& outputs)
|
||||
|
||||
@@ -0,0 +1,114 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/testing/conv_fwd.hpp"
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
|
||||
/// This file contains the implementation details for invoking/testing
|
||||
/// grouped convolution operations using the reference implementation.
|
||||
/// The main item is the `run()` function, which is the primary way to
|
||||
/// invoke the reference execution mechanism.
|
||||
/// The implementation of this file mostly looks like `conv_fwd_ck.hpp`,
|
||||
/// but it's made specific to the reference implementation, which is
|
||||
/// invoked in a slightly different way.
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
/// @brief Concept for checking whether this is the reference convolution
|
||||
/// implementation.
|
||||
///
|
||||
/// This concept is used to tell whether a convolution implementation is
|
||||
/// likely to be the reference implementation - that is, whether we should
|
||||
/// invoke it like the reference kernel. This is mainly used with `run()` to
|
||||
/// differentiate which implementation that should be invoked.
|
||||
///
|
||||
/// - SIGNATURE is the operation signature.
|
||||
/// - Conv is a convolution instance created by the CK Builder API.
|
||||
template <typename Conv, auto SIGNATURE>
|
||||
concept RefConvInstance = requires(Conv& conv,
|
||||
const void* input,
|
||||
const void* weight,
|
||||
void* output,
|
||||
int G,
|
||||
int N,
|
||||
int K,
|
||||
int C,
|
||||
std::vector<long_index_t> dims) {
|
||||
{
|
||||
conv.Run(input,
|
||||
weight,
|
||||
output,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
dims, // input_spatial
|
||||
dims, // filter_spatial
|
||||
dims, // output_spatial
|
||||
dims, // strides
|
||||
dims, // dilations
|
||||
dims // left_pads
|
||||
)
|
||||
};
|
||||
};
|
||||
|
||||
/// @brief `run()` specialization for forward convolution and the reference
|
||||
/// implementation.
|
||||
///
|
||||
/// @tparam SIGNATURE Forward convolution signature.
|
||||
/// @throws std::runtime_error if the arguments weren't actually valid for the
|
||||
/// operation. This should be caught and reported by the testing framework.
|
||||
///
|
||||
/// @see run()
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> &&
|
||||
// TODO: Maybe we can unify this implementation for bwd/weight too?
|
||||
// for now, just concern ourselves with reference and see when the
|
||||
// rest of the bwd/weight plumbing is there.
|
||||
ConvDirectionIsForward<SIGNATURE>
|
||||
void run(RefConvInstance<SIGNATURE> auto& conv,
|
||||
const Args<SIGNATURE>& args,
|
||||
const Inputs<SIGNATURE>& inputs,
|
||||
const Outputs<SIGNATURE>& outputs)
|
||||
{
|
||||
// We don't want to compute the output dims manually, just get
|
||||
// them via the existing infrastructure
|
||||
const auto param = args.to_ck_conv_param();
|
||||
|
||||
// TODO: The reference convolution is currently missing a few features.
|
||||
// Just throw for now, but regard these as TODO items that should be resolved
|
||||
// eventually.
|
||||
|
||||
// Right pads are not supported right now for some reason.
|
||||
for(auto right_pad : param.input_right_pads_)
|
||||
{
|
||||
if(right_pad != 0)
|
||||
throw std::runtime_error("TODO: Support right pad in reference conv");
|
||||
}
|
||||
|
||||
if(!args.make_input_descriptor().is_packed())
|
||||
throw std::runtime_error("TODO: Support non-packed input tensor in reference conv");
|
||||
if(!args.make_weight_descriptor().is_packed())
|
||||
throw std::runtime_error("TODO: Support non-packed weight tensor in reference conv");
|
||||
if(!args.make_output_descriptor().is_packed())
|
||||
throw std::runtime_error("TODO: Support non-packed output tensor in reference conv");
|
||||
|
||||
conv.Run(inputs.input,
|
||||
inputs.weight,
|
||||
outputs.output,
|
||||
param.G_,
|
||||
param.N_,
|
||||
param.K_,
|
||||
param.C_,
|
||||
param.input_spatial_lengths_,
|
||||
param.filter_spatial_lengths_,
|
||||
param.output_spatial_lengths_,
|
||||
param.conv_filter_strides_,
|
||||
param.conv_filter_dilations_,
|
||||
param.input_left_pads_);
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
@@ -8,6 +8,7 @@
|
||||
#include <vector>
|
||||
#include <sstream>
|
||||
#include <concepts>
|
||||
#include <algorithm>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/testing/type_traits.hpp"
|
||||
@@ -369,6 +370,35 @@ struct TensorDescriptor
|
||||
return get_element_space_size() * data_type_sizeof(DT);
|
||||
}
|
||||
|
||||
/// @brief Check if a tensor is packed in memory.
|
||||
///
|
||||
/// This function checks whether the tensor memory is "packed", that is, whether
|
||||
/// all elements are continuous in memory with no gaps.
|
||||
bool is_packed() const
|
||||
{
|
||||
// First sort by stride, then check if they match the scan of the
|
||||
// sizes.
|
||||
const auto& lengths = inner_descriptor_.get_lengths();
|
||||
const auto& strides = inner_descriptor_.get_strides();
|
||||
|
||||
std::array<size_t, RANK> indices;
|
||||
std::iota(indices.begin(), indices.end(), 0);
|
||||
std::sort(indices.begin(), indices.end(), [&](auto i, auto j) {
|
||||
return strides[i] < strides[j];
|
||||
});
|
||||
|
||||
size_t x = 1;
|
||||
for(size_t i = 0; i < RANK; ++i)
|
||||
{
|
||||
if(strides[indices[i]] != x)
|
||||
return false;
|
||||
|
||||
x *= lengths[indices[i]];
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/// @brief Get a tensor descriptor for the space backing a tensor.
|
||||
///
|
||||
/// This function returns a tensor descriptor which represents the buffer space
|
||||
|
||||
@@ -220,10 +220,13 @@ UniqueInputs<SIGNATURE> alloc_inputs(const Args<SIGNATURE>& args);
|
||||
/// @param args The run-time arguments of the operation.
|
||||
/// @param inputs The operation inputs to initialize with random data.
|
||||
///
|
||||
/// @note This function is explicitly deleted to generate compile errors
|
||||
/// for missing implementations.
|
||||
///
|
||||
/// @see Inputs
|
||||
/// @see tensor_initialization
|
||||
template <auto SIGNATURE>
|
||||
void init_inputs(const Args<SIGNATURE>& args, Inputs<SIGNATURE> inputs);
|
||||
void init_inputs(const Args<SIGNATURE>& args, Inputs<SIGNATURE> inputs) = delete;
|
||||
|
||||
/// @brief Allocate outputs corresponding to a signature.
|
||||
///
|
||||
@@ -236,13 +239,16 @@ void init_inputs(const Args<SIGNATURE>& args, Inputs<SIGNATURE> inputs);
|
||||
///
|
||||
/// @param args The run-time arguments of the operation.
|
||||
///
|
||||
/// @note This function is explicitly deleted to generate compile errors
|
||||
/// for missing implementations.
|
||||
///
|
||||
/// @see Outputs
|
||||
/// @see UniqueOutputs
|
||||
/// @see alloc_buffer()
|
||||
/// @see alloc_tensor_buffer()
|
||||
template <auto SIGNATURE>
|
||||
requires ValidUniqueOutputs<SIGNATURE>
|
||||
UniqueInputs<SIGNATURE> alloc_outputs(const Args<SIGNATURE>& args);
|
||||
UniqueInputs<SIGNATURE> alloc_outputs(const Args<SIGNATURE>& args) = delete;
|
||||
|
||||
/// @brief Compare device operation outputs.
|
||||
///
|
||||
@@ -262,10 +268,14 @@ UniqueInputs<SIGNATURE> alloc_outputs(const Args<SIGNATURE>& args);
|
||||
/// @param actual The actual results, the results of the operation to-be-tested.
|
||||
/// @param expected The expected results, the results of the reference implementation.
|
||||
///
|
||||
/// @note This function is explicitly deleted to generate compile errors
|
||||
/// for missing implementations.
|
||||
///
|
||||
/// @see ValidationReport
|
||||
template <auto SIGNATURE>
|
||||
ValidationReport
|
||||
validate(const Args<SIGNATURE>& args, Outputs<SIGNATURE> actual, Outputs<SIGNATURE> expected);
|
||||
ValidationReport validate(const Args<SIGNATURE>& args,
|
||||
Outputs<SIGNATURE> actual,
|
||||
Outputs<SIGNATURE> expected) = delete;
|
||||
|
||||
/// @brief Invoke a device operation created by CK Builder.
|
||||
///
|
||||
@@ -296,10 +306,13 @@ validate(const Args<SIGNATURE>& args, Outputs<SIGNATURE> actual, Outputs<SIGNATU
|
||||
/// @param inputs The input tensor data. Will not be modified by this function.
|
||||
/// @param outputs The output tensor data. The contents will be overwritten by
|
||||
/// this function.
|
||||
///
|
||||
/// @note This function is explicitly deleted to generate compile errors
|
||||
/// for missing implementations.
|
||||
template <auto SIGNATURE, typename Operation>
|
||||
void run(Operation& operation,
|
||||
const Args<SIGNATURE>& args,
|
||||
const Inputs<SIGNATURE>& inputs,
|
||||
const Outputs<SIGNATURE>& outputs);
|
||||
const Outputs<SIGNATURE>& outputs) = delete;
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <functional>
|
||||
#include <bit>
|
||||
|
||||
/// This file implements functionality related to "validation", ie, functionality
|
||||
/// to compare tensors. The functionality in this file should be testing-framework
|
||||
@@ -48,12 +49,22 @@ struct ValidationReport
|
||||
/// The total number of elements in each tensor.
|
||||
uint64_t total_elements;
|
||||
|
||||
/// The number of elements which were bitwise 0.
|
||||
uint64_t zero_elements;
|
||||
|
||||
/// @brief Check whether the output and reference tensors were both all zeros.
|
||||
///
|
||||
/// If both tensors are all zero, it indicates either an incorrect testing setup
|
||||
/// or an issue with the testing framework. For that reason we also consider that
|
||||
/// a failure.
|
||||
bool is_all_zero() const { return zero_elements == total_elements; }
|
||||
|
||||
/// @brief Return whether the check associated to this case was successful.
|
||||
///
|
||||
/// This function returns whether the check associated to this case was successful,
|
||||
/// which is directly derived from checking whether the number of incorrect elements
|
||||
/// was 0.
|
||||
bool is_ok() const { return wrong_elements == 0; }
|
||||
/// was 0 AND whether the tensor was not all zero.
|
||||
bool is_ok() const { return wrong_elements == 0 && !is_all_zero(); }
|
||||
};
|
||||
|
||||
/// @brief Get comparison cases which were incorrect.
|
||||
@@ -123,10 +134,13 @@ bool ValidationReport::check(std::string_view tensor_name,
|
||||
// Initial pass: count errors
|
||||
|
||||
// Allocate and reset counter
|
||||
auto d_error_count = alloc_buffer(sizeof(uint64_t));
|
||||
check_hip(hipMemset(d_error_count.get(), 0, sizeof(uint64_t)));
|
||||
auto d_counters = alloc_buffer(sizeof(uint64_t) * 2);
|
||||
check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 2));
|
||||
|
||||
tensor_foreach(descriptor.get_lengths(), [=, error_count = d_error_count.get()](auto index) {
|
||||
auto d_error_count = &reinterpret_cast<uint64_t*>(d_counters.get())[0];
|
||||
auto d_zero_count = &reinterpret_cast<uint64_t*>(d_counters.get())[1];
|
||||
|
||||
tensor_foreach(descriptor.get_lengths(), [=](auto index) {
|
||||
using CKType = typename factory::internal::DataTypeToCK<DT>::type;
|
||||
|
||||
const auto* actual = static_cast<const CKType*>(actual_data);
|
||||
@@ -137,21 +151,44 @@ bool ValidationReport::check(std::string_view tensor_name,
|
||||
|
||||
const auto offset = calculate_offset(index, strides);
|
||||
|
||||
const auto o = static_cast<double>(type_convert<float>(actual[offset]));
|
||||
const auto r = static_cast<double>(type_convert<float>(expected[offset]));
|
||||
const auto a = actual[offset];
|
||||
const auto b = expected[offset];
|
||||
|
||||
const auto o = static_cast<double>(type_convert<float>(a));
|
||||
const auto r = static_cast<double>(type_convert<float>(b));
|
||||
const auto err = std::abs(o - r);
|
||||
|
||||
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
|
||||
{
|
||||
// We expect the number of errors to be very low, so just use an atomic
|
||||
// for now.
|
||||
atomicAdd(reinterpret_cast<uint64_t*>(error_count), 1);
|
||||
atomicAdd(d_error_count, 1);
|
||||
}
|
||||
|
||||
// Now compare the numbers bitwise too.
|
||||
// Update the counter if they're both zero.
|
||||
using Bytes = std::array<std::byte, sizeof(CKType)>;
|
||||
bool all_zero = true;
|
||||
for(auto x : std::bit_cast<Bytes>(a))
|
||||
{
|
||||
if(x != std::byte{0})
|
||||
all_zero = false;
|
||||
}
|
||||
for(auto x : std::bit_cast<Bytes>(b))
|
||||
{
|
||||
if(x != std::byte{0})
|
||||
all_zero = false;
|
||||
}
|
||||
if(all_zero)
|
||||
{
|
||||
atomicAdd(d_zero_count, 1);
|
||||
}
|
||||
});
|
||||
|
||||
uint64_t error_count = 0;
|
||||
check_hip(
|
||||
hipMemcpy(&error_count, d_error_count.get(), sizeof(uint64_t), hipMemcpyDeviceToHost));
|
||||
check_hip(hipMemcpy(&error_count, d_error_count, sizeof(uint64_t), hipMemcpyDeviceToHost));
|
||||
uint64_t zero_count = 0;
|
||||
check_hip(hipMemcpy(&zero_count, d_zero_count, sizeof(uint64_t), hipMemcpyDeviceToHost));
|
||||
|
||||
// TODO: Gather detailed coordinates.
|
||||
|
||||
@@ -159,9 +196,10 @@ bool ValidationReport::check(std::string_view tensor_name,
|
||||
.tensor_name = std::string(tensor_name),
|
||||
.wrong_elements = error_count,
|
||||
.total_elements = descriptor.get_element_size(),
|
||||
.zero_elements = zero_count,
|
||||
});
|
||||
|
||||
return error_count == 0;
|
||||
return reports_.back().is_ok();
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/builder/testing/conv_fwd_ck.hpp"
|
||||
#include "ck_tile/builder/testing/conv_fwd_reference.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "testing_utils.hpp"
|
||||
|
||||
@@ -34,6 +35,8 @@ constexpr auto ALGORITHM = cku::ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xd
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
using Reference = ckb::ConvBuilder<SIGNATURE, ckt::ConvAlgorithm_Reference{}>::Instance;
|
||||
|
||||
TEST(Fwd2DFp16_CShufV3_GNHWC, Create)
|
||||
{
|
||||
const auto expected_transfer_parameters = to_string(ALGORITHM);
|
||||
@@ -81,18 +84,17 @@ TEST(Fwd2DFp16_CShufV3_GNHWC, EndToEnd)
|
||||
.cde_elementwise_op = {},
|
||||
};
|
||||
|
||||
auto inputs = ckt::alloc_inputs(args);
|
||||
auto outputs = ckt::alloc_outputs(args);
|
||||
auto inputs = ckt::alloc_inputs(args);
|
||||
auto outputs = ckt::alloc_outputs(args);
|
||||
auto reference = ckt::alloc_outputs(args);
|
||||
|
||||
ckt::init_inputs(args, inputs.get());
|
||||
|
||||
auto conv = Instance{};
|
||||
ckt::run(conv, args, inputs.get(), outputs.get());
|
||||
|
||||
// TODO: This should be allocated via ckt::alloc_outputs() and
|
||||
// initialized via ckt::run() with the reference implementation
|
||||
// instead.
|
||||
auto reference = outputs.get();
|
||||
auto ref_conv = Reference{};
|
||||
ckt::run(ref_conv, args, inputs.get(), reference.get());
|
||||
|
||||
EXPECT_THAT(outputs.get(), MatchesReference(args, reference));
|
||||
EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get()));
|
||||
}
|
||||
|
||||
@@ -170,3 +170,22 @@ TEST(TensorDescriptor, ExtentFromVector)
|
||||
EXPECT_THAT([] { return ckt::Extent<5>::from_vector(std::vector<size_t>{1, 2}); },
|
||||
Throws<std::runtime_error>());
|
||||
}
|
||||
|
||||
TEST(TensorDescriptor, IsPacked)
|
||||
{
|
||||
constexpr auto dt = ckb::DataType::INT32; // Irrelevant for this test
|
||||
EXPECT_TRUE(
|
||||
ckt::make_descriptor<dt>(ckt::Extent{101, 43, 25, 662, 654}, ckt::PackedLeftLayout{})
|
||||
.is_packed());
|
||||
EXPECT_TRUE(
|
||||
ckt::make_descriptor<dt>(ckt::Extent{5334, 235, 1563, 256, 23}, ckt::PackedRightLayout{})
|
||||
.is_packed());
|
||||
EXPECT_TRUE(ckt::make_descriptor<dt>(ckt::Extent{}, ckt::Extent{}).is_packed());
|
||||
EXPECT_TRUE(
|
||||
ckt::make_descriptor<dt>(ckt::Extent{461, 345, 5, 93}, ckt::Extent{160425, 5, 1, 1725})
|
||||
.is_packed());
|
||||
EXPECT_FALSE(
|
||||
ckt::make_descriptor<dt>(ckt::Extent{10, 11, 12}, ckt::Extent{1, 100, 1100}).is_packed());
|
||||
EXPECT_FALSE(
|
||||
ckt::make_descriptor<dt>(ckt::Extent{30, 20, 10}, ckt::Extent{1, 1, 1}).is_packed());
|
||||
}
|
||||
|
||||
@@ -67,7 +67,7 @@ TYPED_TEST(ValidationReportTests, SingleCorrect)
|
||||
// Generate a sort-of-random looking sequence
|
||||
auto generator = [strides = desc.get_strides()](const auto& index) {
|
||||
const auto flat_index = ckt::calculate_offset(index, strides);
|
||||
return static_cast<float>(flat_index * 10'000'019 % 768'351);
|
||||
return static_cast<float>((flat_index + 1) * 10'000'019 % 768'351);
|
||||
};
|
||||
|
||||
ckt::fill_tensor(desc, a.get(), generator);
|
||||
@@ -110,6 +110,27 @@ TYPED_TEST(ValidationReportTests, SingleIncorrect)
|
||||
EXPECT_THAT(errors[0].total_elements, Eq(desc.get_element_size()));
|
||||
}
|
||||
|
||||
TYPED_TEST(ValidationReportTests, ZeroIsIncorrect)
|
||||
{
|
||||
const auto desc = TypeParam::get_descriptor();
|
||||
|
||||
auto a = ckt::alloc_tensor_buffer(desc);
|
||||
auto b = ckt::alloc_tensor_buffer(desc);
|
||||
|
||||
ckt::clear_tensor_buffer(desc, a.get());
|
||||
ckt::clear_tensor_buffer(desc, b.get());
|
||||
|
||||
ckt::ValidationReport report;
|
||||
report.check("zero_is_incorrect", desc, b.get(), a.get());
|
||||
|
||||
const auto errors = report.get_errors();
|
||||
ASSERT_THAT(errors.size(), Eq(1));
|
||||
EXPECT_THAT(errors[0].tensor_name, StrEq("zero_is_incorrect"));
|
||||
EXPECT_THAT(errors[0].wrong_elements, Eq(0));
|
||||
EXPECT_THAT(errors[0].total_elements, Eq(desc.get_element_size()));
|
||||
EXPECT_THAT(errors[0].zero_elements, Eq(desc.get_element_size()));
|
||||
}
|
||||
|
||||
TEST(ValidationReportTests, MultipleSomeIncorrect)
|
||||
{
|
||||
ckt::ValidationReport report;
|
||||
|
||||
Reference in New Issue
Block a user