mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-20 23:27:39 +00:00
* ck-builder: tensor copy function This function copies one tensor to another, so that the memory layout can be changed between them. * ck-builder: fix ck::bhalf literals These types don't work properly. * ck-builder: abstract compare_elements in gpu_verification.hpp and make builder use it This reduces the amount of duplicated code a bit. * ck-builder: add flat tensor iterator This "iterator" type pretends to be a pointer, useful for passing tensors to functions expecting pointer-like types. * ck-builder: integrate validation with ck gpu verification By templating the gpu_verify function over iterators, we can use the new FlatTensorIterator to adapt the function to multi-dimensional tensors without changing either implementation too much. * ck-builder: add check_by_accumulations This changes the gpu_verification.hpp code to also accept "iterator" types for the relevant gpu_verify and gpu_reduce_max functions. * ck: fix test_gpu_verification GenerateRandomData for bhalf is_integer_it<bhalf_t> yields true, but it is not actually an integer. * ck: make gpu_verification kernels be proper persistent kernels Previously these were using a hardcoded value for the grid size. This commit changes that so that the grid size is automatically derived from the kernel's occupancy and the number of multiprocessors on the GPU. * ck: clean up gpu_verification.hpp using block_reduce This implements a small generic block reduce function, and rewrites the rest of gpu_verification.hpp using that function to clean it up a bit. * ck-builder: doc typos * ck-builder: update testing readme with validation interface. * ck-builder: rebase fixes + review comments * ck-builder: fix device integer generation with float types Passing bfloat here causes NaNs due to type_convert performing a bitcast. * ck: another bhalf_t bug CK expects that int-generation with ck::bhalf_t yields bhalf integers, not unsigned integers. 
This makes the logic of FillUniformRandInteger compatible with GeneratorTensor_2<InDataType>, however idiotic that may be.
130 lines
4.5 KiB
C++
130 lines
4.5 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#include "impl/conv_signature_types.hpp"
#include "testing_utils.hpp"
#include "ck_tile/builder/testing/conv/fwd.hpp"
#include "ck_tile/builder/testing/tensor_foreach.hpp"
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <cstddef>
#include <string_view>
#include <vector>
|
|
|
|
namespace ckb = ck_tile::builder;
|
|
namespace ckt = ck_tile::builder::test;
|
|
|
|
using ::testing::ElementsAreArray;
|
|
using ::testing::Eq;
|
|
using ::testing::NotNull;
|
|
|
|
constexpr auto SIGNATURE =
|
|
ckt::ConvSignature{.spatial_dim = 2,
|
|
.direction = ckb::ConvDirection::FORWARD,
|
|
.data_type = ckb::DataType::BF16,
|
|
.accumulation_data_type = ckb::DataType::FP32,
|
|
.input = {.config = {.layout = ckb::TensorLayout::NHWGC}},
|
|
.weight = {.config = {.layout = ckb::TensorLayout::GKYXC}},
|
|
.output = {.config = {.layout = ckb::TensorLayout::NHWGK}}};
|
|
|
|
// Run-time arguments for SIGNATURE used by the tests below: tensor lengths,
// unit filter strides and dilation, zero padding on both sides, and
// default-initialized elementwise operations.
constexpr ckt::Args<SIGNATURE> ARGS{
    .lengths =
        {
            .batch_size      = 17,
            .groups          = 5,
            .input_channels  = 13,
            .output_channels = 44,
            .image           = {.width = 99, .height = 125},
            .filter          = {.width = 9, .height = 4},
        },
    .filter_strides     = {.width = 1, .height = 1},
    .filter_dilation    = {.width = 1, .height = 1},
    .input_left_pad     = {.width = 0, .height = 0},
    .input_right_pad    = {.width = 0, .height = 0},
    .a_elementwise_op   = {},
    .b_elementwise_op   = {},
    .cde_elementwise_op = {},
};
|
|
|
|
using Inputs = ckt::Inputs<SIGNATURE>;
|
|
using Outputs = ckt::Outputs<SIGNATURE>;
|
|
using UniqueInputs = ckt::UniqueInputs<SIGNATURE>;
|
|
using UniqueOutputs = ckt::UniqueOutputs<SIGNATURE>;
|
|
|
|
static_assert(ckt::ValidUniqueInputs<SIGNATURE>);
|
|
static_assert(ckt::ValidUniqueOutputs<SIGNATURE>);
|
|
static_assert(ckt::TensorReflectable<Inputs, SIGNATURE>);
|
|
static_assert(ckt::TensorReflectable<Outputs, SIGNATURE>);
|
|
|
|
// Checks the lengths reported by the input, weight and output descriptors
// that ARGS produces.
TEST(ConvFwdTesting, MakeDescriptors)
{
    // Google Test cannot print std::span, so the lengths are copied into a
    // std::vector to keep failure messages legible.
    const auto lengths_of = [](const auto& desc) {
        const auto span = desc.get_lengths();
        return std::vector(span.begin(), span.end());
    };

    const auto input_lengths  = lengths_of(ARGS.make_input_descriptor());
    const auto weight_lengths = lengths_of(ARGS.make_weight_descriptor());
    const auto output_lengths = lengths_of(ARGS.make_output_descriptor());

    EXPECT_THAT(input_lengths, ElementsAreArray({5, 17, 13, 125, 99}));
    EXPECT_THAT(weight_lengths, ElementsAreArray({5, 44, 13, 4, 9}));
    // With unit stride/dilation and no padding: 125 - 4 + 1 = 122 and
    // 99 - 9 + 1 = 91.
    EXPECT_THAT(output_lengths, ElementsAreArray({5, 17, 44, 122, 91}));
}
|
|
|
|
// Checks that allocating inputs and outputs for ARGS yields non-null pointers
// for the input, weight and output tensors.
TEST(ConvFwdTesting, Alloc)
{
    auto allocated_inputs  = alloc_inputs(ARGS);
    auto allocated_outputs = alloc_outputs(ARGS);

    // Inputs carry both the input and the weight tensor.
    EXPECT_THAT(allocated_inputs.get().input, NotNull());
    EXPECT_THAT(allocated_inputs.get().weight, NotNull());

    EXPECT_THAT(allocated_outputs.get().output, NotNull());
}
|
|
|
|
// Checks ckt::validate on two independently allocated output sets:
//  - when every tensor of both sets is cleared to the same value, no errors
//    are expected;
//  - when the two sets are cleared to different values, one error per
//    reflected tensor field is expected.
TEST(ConvFwdTesting, Validate)
{
    auto a = alloc_outputs(ARGS);
    auto b = alloc_outputs(ARGS);

    // Clears every tensor reflected by Outputs to `a_value` in `a` and to
    // `b_value` in `b`, and returns how many tensor fields were visited.
    // The values go through ck::type_convert rather than a literal, as the
    // original test did.
    const auto fill_outputs = [&](float a_value, float b_value) {
        std::size_t visited = 0;
        Outputs::reflect(ARGS,
                         [&]([[maybe_unused]] std::string_view name,
                             const auto& desc,
                             void* Outputs::*ptr) {
                             ++visited;
                             ckt::clear_tensor_buffer(
                                 desc,
                                 a.get().*ptr,
                                 ck::type_convert<ck::bhalf_t, float>(a_value));
                             ckt::clear_tensor_buffer(
                                 desc,
                                 b.get().*ptr,
                                 ck::type_convert<ck::bhalf_t, float>(b_value));
                         });
        return visited;
    };

    // Positive test
    {
        const std::size_t field_count = fill_outputs(123, 123);
        // Without this check the test would pass trivially if reflect()
        // visited no fields at all.
        EXPECT_GT(field_count, 0u);

        const auto report = ckt::validate(ARGS, a.get(), b.get());
        EXPECT_THAT(report.get_errors().size(), Eq(0));
    }

    // Negative test
    {
        const std::size_t field_count = fill_outputs(2, 1);
        // Guard against a vacuous pass: 0 errors == 0 fields would otherwise
        // satisfy the expectation below.
        EXPECT_GT(field_count, 0u);

        const auto report = ckt::validate(ARGS, a.get(), b.get());
        EXPECT_THAT(report.get_errors().size(), Eq(field_count));
    }
}
|