mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_BUILDER] conv bwd weight testing (#3618)
* ck-builder: restructure testing conv In order to prepare for bwd of conv testing, this commit moves some files and types around so that we can reuse ckt::Args for both forward and backwards convolution. * ck-builder: decouple fwd_ck.hpp and fwd_reference.hpp from fwd.hpp This will allow us to more easily include fwd.hpp from backwards definitions, which is required for initializing bwd values. * ck-builder: fix layout of test_ckb_conv_bwd_weight_xdl_cshuffle_v3 Turns out that the supplied layout isn't actually supported... * ck-builder: ck and reference conv integration for bwd weight * ck-builder: ck bwd weight execution test * ck-builder: ckt::run support for ck-tile bwd weight * ck-builder: ck tile bwd weight execution test * ck-builder: extra debug printing in MatchesReference * ck-builder: make ckt::run return RunResult This type is more convenient than std::tuple, as it will allow us to use google test matchers with this in the future. * ck-builder: RunResult matcher Using EXPECT_THAT(..., SuccessfulRun()) will generate a check and a nice error message about how and why running an algorithm failed. * ck-builder: doc fixes * ck-builder: add missing headers
This commit is contained in:
@@ -7,26 +7,25 @@
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/testing/testing.hpp"
|
||||
#include "ck_tile/builder/testing/testing_reflect.hpp"
|
||||
#include "ck_tile/builder/testing/filter_extent.hpp"
|
||||
#include "ck_tile/builder/testing/tensor_buffer.hpp"
|
||||
#include "ck_tile/host/convolution_parameter.hpp"
|
||||
#include "ck_tile/builder/testing/tensor_initialization.hpp"
|
||||
#include "ck_tile/builder/testing/tensor_descriptor.hpp"
|
||||
#include "ck_tile/builder/testing/validation.hpp"
|
||||
#include "ck_tile/host/convolution_parameter.hpp"
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
|
||||
|
||||
/// This file implements common functionality for invoking/testing grouped
|
||||
/// forward convolutions created through the CK Builder API. The main item
|
||||
/// of it is the ConvArgs structure - which contains a complete description
|
||||
/// of it is the Args structure - which contains a complete description
|
||||
/// of a convolution operation.
|
||||
///
|
||||
/// It is not intended that this file contains implementation details for
|
||||
/// actually launching a convolution operation. As this can be done
|
||||
/// through different APIs depending on the kernel (CK, CK Tile, or a
|
||||
/// reference implementation), the code dealing with that is split out
|
||||
/// into a separate header for each implementation.
|
||||
/// into a separate header for each implementation. Nor does this file
|
||||
/// deal with details for defining the data types (`Inputs` and `Outputs`)
|
||||
/// for different conv directions, that is also split out into separate
|
||||
/// headers to keep this one small.
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
@@ -56,7 +55,7 @@ struct ConvTensorLengths
|
||||
///
|
||||
/// @see Args
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE>
|
||||
struct Args<SIGNATURE>
|
||||
{
|
||||
constexpr static auto SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
@@ -204,53 +203,4 @@ struct Args<SIGNATURE>
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief `Inputs` specialization for forward convolution.
|
||||
///
|
||||
/// @tparam SIGNATURE Forward convolution signature.
|
||||
///
|
||||
/// @see Inputs
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
|
||||
struct Inputs<SIGNATURE>
|
||||
{
|
||||
void* input;
|
||||
void* weight;
|
||||
|
||||
static void reflect(const Args<SIGNATURE>& args, const auto& inspect)
|
||||
{
|
||||
inspect("input", args.make_input_descriptor(), &Inputs<SIGNATURE>::input);
|
||||
inspect("weight", args.make_weight_descriptor(), &Inputs<SIGNATURE>::weight);
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief `Outputs` specialization for forward convolution.
|
||||
///
|
||||
/// @tparam SIGNATURE Forward convolution signature.
|
||||
///
|
||||
/// @see Outputs
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
|
||||
struct Outputs<SIGNATURE>
|
||||
{
|
||||
void* output;
|
||||
|
||||
static void reflect(const Args<SIGNATURE>& args, const auto& inspect)
|
||||
{
|
||||
inspect("output", args.make_output_descriptor(), &Outputs<SIGNATURE>::output);
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief `init_inputs()` specialization for forward convolution.
|
||||
///
|
||||
/// @tparam SIGNATURE Forward convolution signature.
|
||||
///
|
||||
/// @see alloc_inputs()
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
|
||||
void init_inputs(const Args<SIGNATURE>& args, Inputs<SIGNATURE> inputs)
|
||||
{
|
||||
init_tensor_buffer_uniform_fp(inputs.input, args.make_input_descriptor(), -2.0f, 2.0f);
|
||||
init_tensor_buffer_uniform_fp(inputs.weight, args.make_weight_descriptor(), -2.0f, 2.0f);
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
@@ -0,0 +1,71 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/testing/tensor_initialization.hpp"
|
||||
#include "ck_tile/builder/testing/testing_reflect.hpp"
|
||||
#include "ck_tile/builder/testing/conv/args.hpp"
|
||||
#include "ck_tile/builder/testing/conv/fwd.hpp"
|
||||
#include "ck_tile/builder/testing/error.hpp"
|
||||
|
||||
/// This file deals with the backward weight-specific details of running grouped
|
||||
/// convolution backwards weight operations. It mainly defines the data
|
||||
/// structures (`Inputs` and `Outputs`), initialization, and validation. Note
|
||||
/// that for this operation specifically, many of the operations are
|
||||
/// implemented automatically via testing_reflect.hpp.
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
/// @brief `Inputs` specialization for backwards weight convolution.
|
||||
///
/// A backwards weight convolution reads two tensors, `input` and `output`.
/// Both are stored as untyped pointers; their layouts are described by
/// `Args<SIGNATURE>::make_input_descriptor()` and
/// `Args<SIGNATURE>::make_output_descriptor()` respectively.
/// NOTE(review): `output` is presumably the gradient with respect to the
/// forward output - confirm against the kernel implementations.
///
|
||||
/// @tparam SIGNATURE Backwards weight convolution signature.
|
||||
///
|
||||
/// @see Inputs
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsBackwardWeight<SIGNATURE>
|
||||
struct Inputs<SIGNATURE>
|
||||
{
|
||||
// Untyped pointer to the input tensor buffer.
void* input;
|
||||
// Untyped pointer to the output tensor buffer (an input of this operation).
void* output;
|
||||
|
||||
// See testing_reflect.hpp
// Reports each tensor of this structure to `inspect` as a
// (name, descriptor, pointer-to-member) triple.
static void reflect(const Args<SIGNATURE>& args, const auto& inspect)
|
||||
{
|
||||
inspect("input", args.make_input_descriptor(), &Inputs<SIGNATURE>::input);
|
||||
inspect("output", args.make_output_descriptor(), &Inputs<SIGNATURE>::output);
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief `Outputs` specialization for backwards weight convolution.
|
||||
///
/// A backwards weight convolution produces a single tensor, `weight`, stored
/// as an untyped pointer. Its layout is described by
/// `Args<SIGNATURE>::make_weight_descriptor()`.
/// NOTE(review): presumably this holds the weight gradient - confirm.
///
|
||||
/// @tparam SIGNATURE Backwards weight convolution signature.
|
||||
///
|
||||
/// @see Outputs
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsBackwardWeight<SIGNATURE>
|
||||
struct Outputs<SIGNATURE>
|
||||
{
|
||||
// Untyped pointer to the weight tensor buffer written by the operation.
void* weight;
|
||||
|
||||
// See testing_reflect.hpp
// Reports the single tensor of this structure to `inspect` as a
// (name, descriptor, pointer-to-member) triple.
static void reflect(const Args<SIGNATURE>& args, const auto& inspect)
|
||||
{
|
||||
inspect("weight", args.make_weight_descriptor(), &Outputs<SIGNATURE>::weight);
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief `init_inputs()` specialization for backwards weight convolution.
|
||||
///
/// Initializes both tensors of `inputs` (`input` and `output`) through
/// `init_tensor_buffer_uniform_fp()` with the bounds -2.0f and 2.0f.
/// NOTE(review): going by the helper name, the values are presumably drawn
/// from a uniform floating point distribution over that range - confirm in
/// tensor_initialization.hpp.
///
|
||||
/// @tparam SIGNATURE Backwards weight convolution signature.
/// @param args Problem description; supplies the tensor descriptors.
/// @param inputs Passed by value: only the pointers are copied, the buffers
///               they refer to are the ones being initialized.
|
||||
///
|
||||
/// @see init_inputs()
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsBackwardWeight<SIGNATURE>
|
||||
void init_inputs(const Args<SIGNATURE>& args, Inputs<SIGNATURE> inputs)
|
||||
{
|
||||
init_tensor_buffer_uniform_fp(inputs.input, args.make_input_descriptor(), -2.0f, 2.0f);
|
||||
init_tensor_buffer_uniform_fp(inputs.output, args.make_output_descriptor(), -2.0f, 2.0f);
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
@@ -0,0 +1,276 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/testing/testing.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include <type_traits>
|
||||
#include <array>
|
||||
|
||||
/// This file contains the implementation details for invoking/testing
|
||||
/// bwd grouped convolution operations in old CK. The main item is the
|
||||
/// `run()` function, which is the main implementation used to invoke
|
||||
/// CK grouped backward weight convolution kernels.
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// @brief Concept for checking whether a bwd weight convolution is invoked like old CK.
|
||||
///
|
||||
/// This is the same as `::ck_tile::builder::test::CkConvBwdWeightInstance`, except
|
||||
/// with some utility aliases. For that reason, it is moved to this detail
|
||||
/// namespace.
///
/// The concept holds when SIGNATURE is a valid backward weight signature and
/// `conv.MakeArgument(...)` is well-formed for the argument list spelled out
/// below (without any Multiple-D arguments). Only the types matter here, so
/// the same placeholder values are reused for several parameters.
///
/// - Conv is the convolution instance type being checked.
/// - SIGNATURE is the operation signature.
/// - SPATIAL_DIM, Types and Ops are utility aliases derived from SIGNATURE;
///   callers are not expected to override them.
|
||||
template <typename Conv,
|
||||
auto SIGNATURE,
|
||||
size_t SPATIAL_DIM = SIGNATURE.spatial_dim,
|
||||
// TODO: We should not need to call into an internal namespace here.
|
||||
typename Types = factory::internal::ConvTensorDataTypes<SIGNATURE>,
|
||||
typename Ops = factory::internal::ConvElementwiseOps<SIGNATURE>>
|
||||
concept CkConvBwdWeightInstance = requires(Conv& conv,
|
||||
// Typed tensor pointers: input and output are read-only, weight is writable.
const Types::InDataType* p_a,
|
||||
Types::WeiDataType* p_b,
|
||||
const Types::OutDataType* p_e,
|
||||
// NOTE(review): these arrays use unqualified `index_t` while `split_k`
// below uses `ck::index_t` - confirm both name the same type.
std::array<index_t, SPATIAL_DIM + 3> lengths,
|
||||
std::array<index_t, SPATIAL_DIM + 3> strides,
|
||||
std::array<index_t, SPATIAL_DIM> filter,
|
||||
Ops::InElementwiseOp elementwise_a,
|
||||
Ops::WeiElementwiseOp elementwise_b,
|
||||
Ops::OutElementwiseOp elementwise_cde,
|
||||
ck::index_t split_k) {
|
||||
requires ValidConvSignature<SIGNATURE>;
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE>;
|
||||
|
||||
{
|
||||
conv.MakeArgument(p_a,
|
||||
p_b,
|
||||
p_e,
|
||||
// A lengths/strides
|
||||
lengths,
|
||||
strides,
|
||||
// B lengths/strides
|
||||
lengths,
|
||||
strides,
|
||||
// E lengths/strides
|
||||
lengths,
|
||||
strides,
|
||||
// strides/dilations/pads
|
||||
filter,
|
||||
filter,
|
||||
filter,
|
||||
filter,
|
||||
// element-wise operations.
|
||||
elementwise_a,
|
||||
elementwise_b,
|
||||
elementwise_cde,
|
||||
split_k)
|
||||
};
|
||||
};
|
||||
|
||||
/// @brief Concept for checking whether a bwd weight convolution is multiple-d and
|
||||
/// invoked like old CK.
|
||||
///
|
||||
/// This is the same as `::ck_tile::builder::test::CkConvBwdWeightMultipleDInstance`, except
|
||||
/// with some utility aliases. For that reason, it is moved to this detail
|
||||
/// namespace.
///
/// Compared to `detail::CkConvBwdWeightInstance`, the checked
/// `conv.MakeArgument(...)` call takes three additional arguments for the
/// D tensors (pointers, lengths and strides). They are all passed as `{}`
/// because Multiple-D is not actually supported yet (see the TODOs below).
///
/// - Conv is the convolution instance type being checked.
/// - SIGNATURE is the operation signature.
/// - SPATIAL_DIM, Types and Ops are utility aliases derived from SIGNATURE;
///   callers are not expected to override them.
|
||||
template <typename Conv,
|
||||
auto SIGNATURE,
|
||||
size_t SPATIAL_DIM = SIGNATURE.spatial_dim,
|
||||
// TODO: We should not need to call into an internal namespace here.
|
||||
typename Types = factory::internal::ConvTensorDataTypes<SIGNATURE>,
|
||||
typename Ops = factory::internal::ConvElementwiseOps<SIGNATURE>>
|
||||
concept CkConvBwdWeightMultipleDInstance = requires(Conv& conv,
|
||||
// Typed tensor pointers: input and output are read-only, weight is writable.
const Types::InDataType* p_a,
|
||||
Types::WeiDataType* p_b,
|
||||
const Types::OutDataType* p_e,
|
||||
// NOTE(review): these arrays use unqualified `index_t` while `split_k`
// below uses `ck::index_t` - confirm both name the same type.
std::array<index_t, SPATIAL_DIM + 3> lengths,
|
||||
std::array<index_t, SPATIAL_DIM + 3> strides,
|
||||
std::array<index_t, SPATIAL_DIM> filter,
|
||||
Ops::InElementwiseOp elementwise_a,
|
||||
Ops::WeiElementwiseOp elementwise_b,
|
||||
Ops::OutElementwiseOp elementwise_cde,
|
||||
ck::index_t split_k) {
|
||||
requires ValidConvSignature<SIGNATURE>;
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE>;
|
||||
|
||||
{
|
||||
conv.MakeArgument(p_a,
|
||||
p_b,
|
||||
p_e,
|
||||
// TODO: Actually support multiple d
|
||||
{},
|
||||
// A lengths/strides
|
||||
lengths,
|
||||
strides,
|
||||
// B lengths/strides
|
||||
lengths,
|
||||
strides,
|
||||
// E lengths/strides
|
||||
lengths,
|
||||
strides,
|
||||
// TODO: Multiple D lengths/strides
|
||||
{},
|
||||
{},
|
||||
// strides/dilations/pads
|
||||
filter,
|
||||
filter,
|
||||
filter,
|
||||
filter,
|
||||
// element-wise operations.
|
||||
elementwise_a,
|
||||
elementwise_b,
|
||||
elementwise_cde,
|
||||
split_k)
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/// @brief Concept for checking whether a bwd weight convolution is invoked like old CK.
|
||||
///
/// Public two-parameter form of `detail::CkConvBwdWeightInstance`; the
/// remaining template parameters of the detail concept keep their defaults.
/// It is used to constrain the matching `run()` overload in this file.
///
|
||||
/// - SIGNATURE is the operation signature.
|
||||
/// - Conv is a convolution instance created by the CK Builder API.
|
||||
template <typename Conv, auto SIGNATURE>
|
||||
concept CkConvBwdWeightInstance = detail::CkConvBwdWeightInstance<Conv, SIGNATURE>;
|
||||
|
||||
/// @brief Concept for checking whether a bwd weight convolution is multiple-d and
|
||||
/// invoked like old CK.
|
||||
///
/// Public two-parameter form of `detail::CkConvBwdWeightMultipleDInstance`;
/// the remaining template parameters of the detail concept keep their
/// defaults. It is used to constrain the Multiple-D `run()` overload in this
/// file.
///
|
||||
/// - SIGNATURE is the operation signature.
|
||||
/// - Conv is a convolution instance created by the CK Builder API.
|
||||
template <typename Conv, auto SIGNATURE>
|
||||
concept CkConvBwdWeightMultipleDInstance =
|
||||
detail::CkConvBwdWeightMultipleDInstance<Conv, SIGNATURE>;
|
||||
|
||||
/// @brief `run()` specialization for backward weight convolution and old CK.
|
||||
///
/// Converts the tensor descriptors and convolution parameters derived from
/// `args` into the fixed-size arrays passed to `conv.MakeArgument()`, checks
/// the resulting argument with `conv.IsSupportedArgument()`, and then runs it
/// through `conv.MakeInvoker()`.
///
|
||||
/// @tparam SIGNATURE Backward weight convolution signature.
/// @param conv Instance satisfying `CkConvBwdWeightInstance<SIGNATURE>`.
/// @param args Problem description (descriptors, conv parameters,
///             element-wise ops and `k_batch`).
/// @param inputs Supplies `input` and `output`, which are cast to
///               pointer-to-const before being handed to CK.
/// @param outputs Supplies `weight`, which is handed to CK as a mutable pointer.
|
||||
/// @returns RunResult about how the operation completed (or not): a
///          `RunResult::not_supported` value when CK rejects the arguments,
///          otherwise `RunResult::from_runtime` of the value returned by the
///          invoker.
|
||||
///
|
||||
/// @see run()
|
||||
template <auto SIGNATURE>
|
||||
[[nodiscard]] RunResult run(CkConvBwdWeightInstance<SIGNATURE> auto& conv,
|
||||
const Args<SIGNATURE>& args,
|
||||
const Inputs<SIGNATURE>& inputs,
|
||||
const Outputs<SIGNATURE>& outputs)
|
||||
{
|
||||
using Types = factory::internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
|
||||
constexpr auto spatial_dim = SIGNATURE.spatial_dim;
|
||||
|
||||
// Element-wise copy between containers; the element type may differ, in
// which case each element is converted on assignment.
// NOTE(review): std::copy is declared in <algorithm>, which is not among the
// includes visible for this header - confirm it arrives transitively.
const auto copy = [](const auto& src, auto& dst) {
|
||||
std::copy(src.begin(), src.end(), dst.begin());
|
||||
};
|
||||
|
||||
// Converts tensor lengths or strides into the array type used by CK.
// NOTE(review): `result` is not value-initialized and the copy is not
// bounds-checked, so this relies on `src` holding exactly spatial_dim + 3
// elements - confirm against the descriptor type.
const auto to_ck_lengths = [&](const auto& src) {
|
||||
std::array<ck::index_t, spatial_dim + 3> result;
|
||||
copy(src, result);
|
||||
return result;
|
||||
};
|
||||
|
||||
// Same conversion for per-spatial-dimension values (filter strides,
// dilations and pads); relies on `extent` holding exactly spatial_dim
// elements - TODO confirm.
const auto to_ck_extent = [&](const auto& extent) {
|
||||
std::array<ck::index_t, spatial_dim> result;
|
||||
copy(extent, result);
|
||||
return result;
|
||||
};
|
||||
|
||||
const auto param = args.to_ck_conv_param();
|
||||
|
||||
const auto input_desc = args.make_input_descriptor();
|
||||
const auto weight_desc = args.make_weight_descriptor();
|
||||
const auto output_desc = args.make_output_descriptor();
|
||||
|
||||
// Argument order: input, weight, output pointers; then lengths and strides
// for each of those three tensors; then filter strides, dilations, left and
// right pads; then the element-wise operations and finally k_batch.
auto ck_args = conv.MakeArgument(static_cast<const Types::InDataType*>(inputs.input),
|
||||
static_cast<Types::WeiDataType*>(outputs.weight),
|
||||
static_cast<const Types::OutDataType*>(inputs.output),
|
||||
to_ck_lengths(input_desc.get_lengths()),
|
||||
to_ck_lengths(input_desc.get_strides()),
|
||||
to_ck_lengths(weight_desc.get_lengths()),
|
||||
to_ck_lengths(weight_desc.get_strides()),
|
||||
to_ck_lengths(output_desc.get_lengths()),
|
||||
to_ck_lengths(output_desc.get_strides()),
|
||||
to_ck_extent(param.conv_filter_strides_),
|
||||
to_ck_extent(param.conv_filter_dilations_),
|
||||
to_ck_extent(param.input_left_pads_),
|
||||
to_ck_extent(param.input_right_pads_),
|
||||
args.a_elementwise_op,
|
||||
args.b_elementwise_op,
|
||||
args.cde_elementwise_op,
|
||||
args.k_batch);
|
||||
|
||||
// Report unsupported problems to the caller instead of launching.
if(!conv.IsSupportedArgument(ck_args))
|
||||
return RunResult::not_supported("invalid ck arguments");
|
||||
|
||||
// NOTE(review): the second argument is default-constructed with `{}`;
// presumably this is the stream configuration - confirm. Unlike the CK Tile
// overloads, this function has no parameter to override it.
return RunResult::from_runtime(conv.MakeInvoker().Run(ck_args, {}));
|
||||
}
|
||||
|
||||
/// @brief `run()` specialization for backward weight convolution and old CK.
|
||||
///
|
||||
/// This overload is specialized for Multiple-D.
/// It performs the same steps as the non-Multiple-D overload, but the
/// `conv.MakeArgument()` call takes three extra arguments for the D tensors.
/// Those are all passed as `{}` for now (see the TODO markers below).
|
||||
///
|
||||
/// @tparam SIGNATURE Backward weight convolution signature.
/// @param conv Instance satisfying `CkConvBwdWeightMultipleDInstance<SIGNATURE>`.
/// @param args Problem description (descriptors, conv parameters,
///             element-wise ops and `k_batch`).
/// @param inputs Supplies `input` and `output`, which are cast to
///               pointer-to-const before being handed to CK.
/// @param outputs Supplies `weight`, which is handed to CK as a mutable pointer.
|
||||
/// @returns RunResult about how the operation completed (or not): a
///          `RunResult::not_supported` value when CK rejects the arguments,
///          otherwise `RunResult::from_runtime` of the value returned by the
///          invoker.
|
||||
///
|
||||
/// @see run()
|
||||
template <auto SIGNATURE>
|
||||
[[nodiscard]] RunResult run(CkConvBwdWeightMultipleDInstance<SIGNATURE> auto& conv,
|
||||
const Args<SIGNATURE>& args,
|
||||
const Inputs<SIGNATURE>& inputs,
|
||||
const Outputs<SIGNATURE>& outputs)
|
||||
{
|
||||
using Types = factory::internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
|
||||
constexpr auto spatial_dim = SIGNATURE.spatial_dim;
|
||||
|
||||
// Element-wise copy between containers; the element type may differ, in
// which case each element is converted on assignment.
// NOTE(review): std::copy is declared in <algorithm>, which is not among the
// includes visible for this header - confirm it arrives transitively.
const auto copy = [](const auto& src, auto& dst) {
|
||||
std::copy(src.begin(), src.end(), dst.begin());
|
||||
};
|
||||
|
||||
// Converts tensor lengths or strides into the array type used by CK.
// NOTE(review): `result` is not value-initialized and the copy is not
// bounds-checked, so this relies on `src` holding exactly spatial_dim + 3
// elements - confirm against the descriptor type.
const auto to_ck_lengths = [&](const auto& src) {
|
||||
std::array<ck::index_t, spatial_dim + 3> result;
|
||||
copy(src, result);
|
||||
return result;
|
||||
};
|
||||
|
||||
// Same conversion for per-spatial-dimension values (filter strides,
// dilations and pads); relies on `extent` holding exactly spatial_dim
// elements - TODO confirm.
const auto to_ck_extent = [&](const auto& extent) {
|
||||
std::array<ck::index_t, spatial_dim> result;
|
||||
copy(extent, result);
|
||||
return result;
|
||||
};
|
||||
|
||||
const auto param = args.to_ck_conv_param();
|
||||
|
||||
const auto input_desc = args.make_input_descriptor();
|
||||
const auto weight_desc = args.make_weight_descriptor();
|
||||
const auto output_desc = args.make_output_descriptor();
|
||||
|
||||
// Argument order: input, weight, output pointers and the (empty) D pointers;
// lengths and strides for input, weight and output; the (empty) D lengths
// and strides; filter strides, dilations, left and right pads; the
// element-wise operations and finally k_batch.
auto ck_args = conv.MakeArgument(static_cast<const Types::InDataType*>(inputs.input),
|
||||
static_cast<Types::WeiDataType*>(outputs.weight),
|
||||
static_cast<const Types::OutDataType*>(inputs.output),
|
||||
{}, // TODO
|
||||
to_ck_lengths(input_desc.get_lengths()),
|
||||
to_ck_lengths(input_desc.get_strides()),
|
||||
to_ck_lengths(weight_desc.get_lengths()),
|
||||
to_ck_lengths(weight_desc.get_strides()),
|
||||
to_ck_lengths(output_desc.get_lengths()),
|
||||
to_ck_lengths(output_desc.get_strides()),
|
||||
{}, // TODO
|
||||
{}, // TODO
|
||||
to_ck_extent(param.conv_filter_strides_),
|
||||
to_ck_extent(param.conv_filter_dilations_),
|
||||
to_ck_extent(param.input_left_pads_),
|
||||
to_ck_extent(param.input_right_pads_),
|
||||
args.a_elementwise_op,
|
||||
args.b_elementwise_op,
|
||||
args.cde_elementwise_op,
|
||||
args.k_batch);
|
||||
|
||||
// Report unsupported problems to the caller instead of launching.
if(!conv.IsSupportedArgument(ck_args))
|
||||
return RunResult::not_supported("invalid ck arguments");
|
||||
|
||||
// NOTE(review): the second argument is default-constructed with `{}`;
// presumably this is the stream configuration - confirm. Unlike the CK Tile
// overloads, this function has no parameter to override it.
return RunResult::from_runtime(conv.MakeInvoker().Run(ck_args, {}));
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
@@ -3,9 +3,8 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/testing/conv_fwd.hpp"
|
||||
#include "ck_tile/builder/testing/testing.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/ops/gemm.hpp"
|
||||
#include "ck_tile/ops/grouped_convolution.hpp"
|
||||
#include <type_traits>
|
||||
@@ -28,9 +27,39 @@ namespace detail {
|
||||
/// namespace.
|
||||
template <typename Conv, auto SIGNATURE>
|
||||
concept CkTileConvInstance = requires(Conv&) {
|
||||
// The signature itself has to be valid; the convolution direction is not
// restricted here, the `run()` overloads add that constraint themselves.
requires ValidConvSignature<SIGNATURE>;
|
||||
// The only expression checked on Conv: a static `BlockSize()` call must be
// well-formed.
{ Conv::BlockSize() };
|
||||
};
|
||||
|
||||
/// @brief Shared CK Tile launch helper used by the direction-specific `run()`
/// overloads.
///
/// Builds `ck_tile::GroupedConvHostArgs` from `args` and the three tensor
/// pointers, turns them into kernel arguments with `Conv::MakeKernelArgs()`,
/// and launches the kernel through `ck_tile::launch_kernel()`.
///
/// @tparam SIGNATURE Convolution signature.
/// @tparam InDataType, WeiDataType, OutDataType Pointee types of the tensor
///         pointers. The callers decide which of them are const-qualified,
///         which is how the written tensor differs per direction.
/// @param conv The CK Tile convolution instance to launch.
/// @param args Problem description; supplies the conv parameters and `k_batch`.
/// @param input, weight, output The three tensor pointers, in this order.
/// @param s_conf Stream configuration forwarded to `ck_tile::launch_kernel()`.
/// @returns `RunResult::not_supported` when `Conv::IsSupportedArgument()`
///          rejects the kernel arguments, otherwise `RunResult::from_runtime`
///          of the value returned by `ck_tile::launch_kernel()`.
template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename OutDataType>
|
||||
[[nodiscard]] RunResult run(CkTileConvInstance<SIGNATURE> auto& conv,
|
||||
const Args<SIGNATURE>& args,
|
||||
InDataType* input,
|
||||
WeiDataType* weight,
|
||||
OutDataType* output,
|
||||
const ck_tile::stream_config s_conf)
|
||||
{
|
||||
// `conv` is a deduced reference; strip it to reach the static members.
using Conv = std::remove_reference_t<decltype(conv)>;
|
||||
const auto param = args.to_ck_tile_conv_param();
|
||||
|
||||
// NOTE(review): the `{}` between `weight` and `output` is presumably the
// (empty) list of D tensors - confirm against GroupedConvHostArgs.
ck_tile::GroupedConvHostArgs<InDataType*, WeiDataType*, OutDataType*, ck_tile::PassThrough>
|
||||
host_args(param, input, weight, {}, output, args.k_batch);
|
||||
|
||||
auto kargs = Conv::MakeKernelArgs(host_args);
|
||||
|
||||
// Note that the launch dimensions are computed before the support check.
const dim3 grids = Conv::GridSize(kargs);
|
||||
const dim3 blocks = Conv::BlockSize();
|
||||
|
||||
// Report unsupported problems to the caller instead of launching.
if(!Conv::IsSupportedArgument(kargs))
|
||||
return RunResult::not_supported("unsupported ck_tile arguments");
|
||||
|
||||
// 1 for the Intrawave scheduler, 2 for every other scheduler.
constexpr index_t minimum_occupancy =
|
||||
Conv::GemmPipeline::Scheduler == ck_tile::GemmPipelineScheduler::Intrawave ? 1 : 2;
|
||||
|
||||
// NOTE(review): the literal 0 is presumably the dynamic shared memory size
// - confirm against ck_tile::make_kernel.
return RunResult::from_runtime(ck_tile::launch_kernel(
|
||||
s_conf, ck_tile::make_kernel<minimum_occupancy>(conv, grids, blocks, 0, kargs)));
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/// @brief Concept for checking whether a convolution is invoked like CK Tile.
|
||||
@@ -48,44 +77,45 @@ concept CkTileConvInstance = detail::CkTileConvInstance<Conv, SIGNATURE>;
|
||||
/// @brief `run()` specialization for forward convolution and CK Tile.
|
||||
///
|
||||
/// @tparam SIGNATURE Forward convolution signature.
|
||||
/// @throws std::runtime_error if the arguments weren't actually valid for the
|
||||
/// operation. This should be caught and reported by the testing framework.
|
||||
/// @return std::tuple<bool, float> - whether the problem is supported and
|
||||
/// kernel execution time (0.0f if s_conf time_kernel is false).
|
||||
/// @returns RunResult about how the operation completed (or not).
|
||||
///
|
||||
/// @see run()
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
|
||||
std::tuple<bool, float> run(CkTileConvInstance<SIGNATURE> auto& conv,
|
||||
requires ConvDirectionIsForward<SIGNATURE>
|
||||
[[nodiscard]] RunResult run(CkTileConvInstance<SIGNATURE> auto& conv,
|
||||
const Args<SIGNATURE>& args,
|
||||
const Inputs<SIGNATURE>& inputs,
|
||||
const Outputs<SIGNATURE>& outputs,
|
||||
const ck_tile::stream_config s_conf = {})
|
||||
{
|
||||
using Conv = std::remove_reference_t<decltype(conv)>;
|
||||
const auto param = args.to_ck_tile_conv_param();
|
||||
return detail::run(conv,
|
||||
args,
|
||||
static_cast<const void*>(inputs.input),
|
||||
static_cast<const void*>(inputs.weight),
|
||||
static_cast<void*>(outputs.output),
|
||||
s_conf);
|
||||
}
|
||||
|
||||
ck_tile::GroupedConvFwdHostArgs<> host_args(
|
||||
param, inputs.input, inputs.weight, {}, outputs.output, args.k_batch);
|
||||
|
||||
auto kargs = Conv::MakeKernelArgs(host_args);
|
||||
|
||||
const dim3 grids = Conv::GridSize(kargs);
|
||||
const dim3 blocks = Conv::BlockSize();
|
||||
|
||||
if(!Conv::IsSupportedArgument(kargs))
|
||||
{
|
||||
std::cout << "Not supported!";
|
||||
return std::make_tuple(false, 0.f);
|
||||
}
|
||||
|
||||
constexpr index_t minimum_occupancy =
|
||||
Conv::GemmPipeline::Scheduler == ck_tile::GemmPipelineScheduler::Intrawave ? 1 : 2;
|
||||
|
||||
return std::make_tuple(
|
||||
true,
|
||||
ck_tile::launch_kernel(
|
||||
s_conf, ck_tile::make_kernel<minimum_occupancy>(conv, grids, blocks, 0, kargs)));
|
||||
/// @brief `run()` specialization for backwards weight convolution and CK Tile.
|
||||
///
/// Thin wrapper around `detail::run()`. For this direction `inputs.input` and
/// `inputs.output` are passed as pointer-to-const, while `outputs.weight` is
/// passed as a mutable pointer in the weight position.
///
|
||||
/// @tparam SIGNATURE Backwards weight convolution signature.
/// @param conv The CK Tile convolution instance to launch.
/// @param args Problem description.
/// @param inputs Supplies the `input` and `output` tensor pointers.
/// @param outputs Supplies the `weight` tensor pointer.
/// @param s_conf Stream configuration; default-constructed when omitted.
|
||||
/// @returns RunResult about how the operation completed (or not).
|
||||
///
|
||||
/// @see run()
|
||||
template <auto SIGNATURE>
|
||||
requires ConvDirectionIsBackwardWeight<SIGNATURE>
|
||||
[[nodiscard]] RunResult run(CkTileConvInstance<SIGNATURE> auto& conv,
|
||||
const Args<SIGNATURE>& args,
|
||||
const Inputs<SIGNATURE>& inputs,
|
||||
const Outputs<SIGNATURE>& outputs,
|
||||
const ck_tile::stream_config s_conf = {})
|
||||
{
|
||||
// detail::run() takes (input, weight, output) in that order, regardless of
// which of the three tensors is the one being written.
return detail::run(conv,
|
||||
args,
|
||||
static_cast<const void*>(inputs.input),
|
||||
static_cast<void*>(outputs.weight),
|
||||
static_cast<const void*>(inputs.output),
|
||||
s_conf);
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
@@ -0,0 +1,69 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/testing/tensor_initialization.hpp"
|
||||
#include "ck_tile/builder/testing/testing_reflect.hpp"
|
||||
#include "ck_tile/builder/testing/conv/args.hpp"
|
||||
|
||||
/// This file deals with the forward-specific details of running grouped
|
||||
/// convolution forward operations. It mainly defines the data structures
|
||||
/// (`Inputs` and `Outputs`), initialization, and validation. Note that
|
||||
/// for this operation specifically, many of the operations are implemented
|
||||
/// automatically via testing_reflect.hpp.
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
/// @brief `Inputs` specialization for forward convolution.
|
||||
///
/// A forward convolution reads two tensors, `input` and `weight`. Both are
/// stored as untyped pointers; their layouts are described by
/// `Args<SIGNATURE>::make_input_descriptor()` and
/// `Args<SIGNATURE>::make_weight_descriptor()` respectively.
///
|
||||
/// @tparam SIGNATURE Forward convolution signature.
|
||||
///
|
||||
/// @see Inputs
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
|
||||
struct Inputs<SIGNATURE>
|
||||
{
|
||||
// Untyped pointer to the input tensor buffer.
void* input;
|
||||
// Untyped pointer to the weight tensor buffer.
void* weight;
|
||||
|
||||
// See testing_reflect.hpp
// Reports each tensor of this structure to `inspect` as a
// (name, descriptor, pointer-to-member) triple.
static void reflect(const Args<SIGNATURE>& args, const auto& inspect)
|
||||
{
|
||||
inspect("input", args.make_input_descriptor(), &Inputs<SIGNATURE>::input);
|
||||
inspect("weight", args.make_weight_descriptor(), &Inputs<SIGNATURE>::weight);
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief `Outputs` specialization for forward convolution.
|
||||
///
/// A forward convolution produces a single tensor, `output`, stored as an
/// untyped pointer. Its layout is described by
/// `Args<SIGNATURE>::make_output_descriptor()`.
///
|
||||
/// @tparam SIGNATURE Forward convolution signature.
|
||||
///
|
||||
/// @see Outputs
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
|
||||
struct Outputs<SIGNATURE>
|
||||
{
|
||||
// Untyped pointer to the output tensor buffer written by the operation.
void* output;
|
||||
|
||||
// See testing_reflect.hpp
// Reports the single tensor of this structure to `inspect` as a
// (name, descriptor, pointer-to-member) triple.
static void reflect(const Args<SIGNATURE>& args, const auto& inspect)
|
||||
{
|
||||
inspect("output", args.make_output_descriptor(), &Outputs<SIGNATURE>::output);
|
||||
}
|
||||
};
|
||||
|
||||
/// @brief `init_inputs()` specialization for forward convolution.
|
||||
///
/// Initializes both tensors of `inputs` (`input` and `weight`) through
/// `init_tensor_buffer_uniform_fp()` with the bounds -2.0f and 2.0f.
/// NOTE(review): going by the helper name, the values are presumably drawn
/// from a uniform floating point distribution over that range - confirm in
/// tensor_initialization.hpp.
///
|
||||
/// @tparam SIGNATURE Forward convolution signature.
/// @param args Problem description; supplies the tensor descriptors.
/// @param inputs Passed by value: only the pointers are copied, the buffers
///               they refer to are the ones being initialized.
|
||||
///
|
||||
/// @see init_inputs()
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
|
||||
void init_inputs(const Args<SIGNATURE>& args, Inputs<SIGNATURE> inputs)
|
||||
{
|
||||
init_tensor_buffer_uniform_fp(inputs.input, args.make_input_descriptor(), -2.0f, 2.0f);
|
||||
init_tensor_buffer_uniform_fp(inputs.weight, args.make_weight_descriptor(), -2.0f, 2.0f);
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
@@ -3,14 +3,14 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/testing/conv_fwd.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/builder/testing/testing.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include <type_traits>
|
||||
#include <array>
|
||||
|
||||
/// This file contains the implementation details for invoking/testing
|
||||
/// grouped convolution operations in old CK. The main item is the
|
||||
/// fwd grouped convolution operations in old CK. The main item is the
|
||||
/// `run()` function, which is the main implementation used to invoke
|
||||
/// CK grouped forward convolution kernels.
|
||||
|
||||
@@ -18,10 +18,9 @@ namespace ck_tile::builder::test {
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// @brief Concept for checking whether this is the reference convolution
|
||||
/// implementation.
|
||||
/// @brief Concept for checking whether a fwd convolution is invoked like old CK.
|
||||
///
|
||||
/// This is the same as `::ck_tile::builder::test::CkConvInstance`, except
|
||||
/// This is the same as `::ck_tile::builder::test::CkConvFwdInstance`, except
|
||||
/// with some utility aliases. For that reason, it is moved to this detail
|
||||
/// namespace.
|
||||
template <typename Conv,
|
||||
@@ -29,18 +28,21 @@ template <typename Conv,
|
||||
size_t SPATIAL_DIM = SIGNATURE.spatial_dim,
|
||||
// TODO: We shouldn't need to call into an internal namespace here.
|
||||
typename Ops = factory::internal::ConvElementwiseOps<SIGNATURE>>
|
||||
concept CkConvInstance = requires(Conv& conv,
|
||||
// TODO: This should be changed depending on IsMultiA etc.
|
||||
// Currently that is not yet supported elsewhere anyway.
|
||||
const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_e,
|
||||
std::array<index_t, SPATIAL_DIM + 3> lengths,
|
||||
std::array<index_t, SPATIAL_DIM + 3> strides,
|
||||
std::array<index_t, SPATIAL_DIM> filter,
|
||||
Ops::InElementwiseOp elementwise_a,
|
||||
Ops::WeiElementwiseOp elementwise_b,
|
||||
Ops::OutElementwiseOp elementwise_cde) {
|
||||
concept CkConvFwdInstance = requires(Conv& conv,
|
||||
// TODO: This should be changed depending on IsMultiA etc.
|
||||
// Currently that is not yet supported elsewhere anyway.
|
||||
const void* p_a,
|
||||
const void* p_b,
|
||||
void* p_e,
|
||||
std::array<index_t, SPATIAL_DIM + 3> lengths,
|
||||
std::array<index_t, SPATIAL_DIM + 3> strides,
|
||||
std::array<index_t, SPATIAL_DIM> filter,
|
||||
Ops::InElementwiseOp elementwise_a,
|
||||
Ops::WeiElementwiseOp elementwise_b,
|
||||
Ops::OutElementwiseOp elementwise_cde) {
|
||||
requires ValidConvSignature<SIGNATURE>;
|
||||
requires ConvDirectionIsForward<SIGNATURE>;
|
||||
|
||||
{
|
||||
conv.MakeArgument(p_a,
|
||||
p_b,
|
||||
@@ -73,7 +75,7 @@ concept CkConvInstance = requires(Conv& conv,
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/// @brief Concept for checking whether a convolution is invoked like old CK.
|
||||
/// @brief Concept for checking whether a fwd convolution is invoked like old CK.
|
||||
///
|
||||
/// This concept is used to tell whether a convolution implementation is
|
||||
/// likely to be an "old CK" implementation - that is, whether we should
|
||||
@@ -83,20 +85,17 @@ concept CkConvInstance = requires(Conv& conv,
|
||||
/// - SIGNATURE is the operation signature.
|
||||
/// - Conv is a convolution instance created by the CK Builder API.
|
||||
template <typename Conv, auto SIGNATURE>
|
||||
concept CkConvInstance = detail::CkConvInstance<Conv, SIGNATURE>;
|
||||
concept CkConvFwdInstance = detail::CkConvFwdInstance<Conv, SIGNATURE>;
|
||||
|
||||
/// @brief `run()` specialization for forward convolution and old CK.
|
||||
///
|
||||
/// @tparam SIGNATURE Forward convolution signature.
|
||||
/// @throws std::runtime_error if the arguments weren't actually valid for the
|
||||
/// operation. This should be caught and reported by the testing framework.
|
||||
/// @return std::tuple<bool, float> - whether the problem is supported and
|
||||
/// kernel execution time (0.0f if s_conf time_kernel is false).
|
||||
/// @returns RunResult about how the operation completed (or not).
|
||||
///
|
||||
/// @see run()
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> && ConvDirectionIsForward<SIGNATURE>
|
||||
std::tuple<bool, float> run(CkConvInstance<SIGNATURE> auto& conv,
|
||||
[[nodiscard]] RunResult run(CkConvFwdInstance<SIGNATURE> auto& conv,
|
||||
const Args<SIGNATURE>& args,
|
||||
const Inputs<SIGNATURE>& inputs,
|
||||
const Outputs<SIGNATURE>& outputs,
|
||||
@@ -126,6 +125,9 @@ std::tuple<bool, float> run(CkConvInstance<SIGNATURE> auto& conv,
|
||||
const auto weight_desc = args.make_weight_descriptor();
|
||||
const auto output_desc = args.make_output_descriptor();
|
||||
|
||||
if(args.k_batch != 1)
|
||||
return RunResult::not_supported("ck fwd does not support k_batch != 1");
|
||||
|
||||
auto ck_args = conv.MakeArgument(inputs.input,
|
||||
inputs.weight,
|
||||
{},
|
||||
@@ -147,11 +149,9 @@ std::tuple<bool, float> run(CkConvInstance<SIGNATURE> auto& conv,
|
||||
args.cde_elementwise_op);
|
||||
|
||||
if(!conv.IsSupportedArgument(ck_args))
|
||||
{
|
||||
std::cout << "invalid argument" << std::endl;
|
||||
}
|
||||
return RunResult::not_supported("unsupported ck arguments");
|
||||
|
||||
return std::make_tuple(true, conv.MakeInvoker().Run(ck_args, s_conf));
|
||||
return RunResult::from_runtime(conv.MakeInvoker().Run(ck_args, s_conf));
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
@@ -0,0 +1,137 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/testing/testing.hpp"
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
|
||||
/// This file contains the implementation details for invoking/testing
|
||||
/// grouped convolution operations using the reference implementation.
|
||||
/// The main item is the `run()` function, which is the primary way to
|
||||
/// invoke the reference execution mechanism.
|
||||
/// The implementation of this file mostly looks like `conv_fwd_ck.hpp`,
|
||||
/// but it's made specific to the reference implementation, which is
|
||||
/// invoked in a slightly different way.
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
namespace detail {
|
||||
|
||||
/// @brief Concept for checking whether this is the reference convolution
|
||||
/// implementation.
|
||||
///
|
||||
/// This concept is used to tell whether a convolution implementation is
|
||||
/// likely to be the reference implementation - that is, whether we should
|
||||
/// invoke it like the reference kernel. This is mainly used with `run()` to
|
||||
/// differentiate which implementation should be invoked.
|
||||
///
|
||||
/// - SIGNATURE is the operation signature.
|
||||
/// - Conv is a convolution instance created by the CK Builder API.
|
||||
/// - InDataType, WeiDataType, OutDataType are the types of the respective tensors.
|
||||
template <typename Conv,
|
||||
auto SIGNATURE,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType>
|
||||
concept RefConvInstance = requires(Conv& conv,
|
||||
InDataType* input,
|
||||
WeiDataType* weight,
|
||||
OutDataType* output,
|
||||
ck::utils::conv::ConvParam param) {
|
||||
requires ValidConvSignature<SIGNATURE>;
|
||||
{ conv.Run(input, weight, output, param) };
|
||||
};
|
||||
|
||||
/// @brief Generic `run` implementation for forward/backwards reference kernels.
|
||||
///
|
||||
/// @tparam SIGNATURE The signature of the operation to perform.
|
||||
///
|
||||
/// @returns RunResult about how the operation completed (or not). On success
|
||||
/// the runtime is always 0, as the reference does not report a meaningful one.
|
||||
/// @see run()
|
||||
template <auto SIGNATURE, typename InDataType, typename WeiDataType, typename OutDataType>
|
||||
[[nodiscard]] RunResult
|
||||
run(RefConvInstance<SIGNATURE, InDataType, WeiDataType, OutDataType> auto& conv,
|
||||
const Args<SIGNATURE>& args,
|
||||
InDataType* input,
|
||||
WeiDataType* weight,
|
||||
OutDataType* output)
|
||||
{
|
||||
// We don't want to compute the output dims manually, just get
|
||||
// them via the existing infrastructure
|
||||
const auto param = args.to_ck_conv_param();
|
||||
|
||||
// TODO: The reference convolution is currently missing a few features.
|
||||
// Just report "not supported" for now, but regard these as TODO items that should be resolved
|
||||
// eventually.
|
||||
|
||||
if(!args.make_input_descriptor().is_packed())
|
||||
return RunResult::not_supported("TODO: Support non-packed input tensor in reference conv");
|
||||
|
||||
if(!args.make_weight_descriptor().is_packed())
|
||||
return RunResult::not_supported("TODO: Support non-packed weight tensor in reference conv");
|
||||
|
||||
if(!args.make_output_descriptor().is_packed())
|
||||
return RunResult::not_supported("TODO: Support non-packed output tensor in reference conv");
|
||||
|
||||
conv.Run(input, weight, output, param);
|
||||
return RunResult::from_runtime(0); // ref conv does not return a meaningful runtime.
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
/// @brief Concept for checking whether this is the reference convolution
|
||||
/// forward implementation.
|
||||
template <typename Conv, auto SIGNATURE>
|
||||
concept RefConvFwdInstance =
|
||||
detail::RefConvInstance<Conv, SIGNATURE, const void*, const void*, void*> &&
|
||||
ConvDirectionIsForward<SIGNATURE>;
|
||||
|
||||
/// @brief `run()` specialization for forward convolution and the reference
|
||||
/// forward implementation.
|
||||
///
|
||||
/// @tparam SIGNATURE The signature of the operation to perform. Must be forwards.
|
||||
/// @returns RunResult about how the operation completed (or not).
|
||||
///
|
||||
/// @see run()
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> &&
|
||||
// TODO: Maybe we can unify this implementation for bwd/weight too?
|
||||
// for now, just concern ourselves with reference and see when the
|
||||
// rest of the bwd/weight plumbing is there.
|
||||
ConvDirectionIsForward<SIGNATURE>
|
||||
[[nodiscard]] RunResult run(RefConvFwdInstance<SIGNATURE> auto& conv,
|
||||
const Args<SIGNATURE>& args,
|
||||
const Inputs<SIGNATURE>& inputs,
|
||||
const Outputs<SIGNATURE>& outputs)
|
||||
{
|
||||
return detail::run(conv, args, inputs.input, inputs.weight, outputs.output);
|
||||
}
|
||||
|
||||
/// @brief Concept for checking whether this is the reference convolution
|
||||
/// backward weight implementation.
|
||||
template <typename Conv, auto SIGNATURE>
|
||||
concept RefConvBwdWeightInstance =
|
||||
detail::RefConvInstance<Conv, SIGNATURE, const void*, void*, const void*> &&
|
||||
ConvDirectionIsBackwardWeight<SIGNATURE>;
|
||||
|
||||
/// @brief `run()` specialization for backward weight convolution and the reference
|
||||
/// backward weight implementation.
|
||||
///
|
||||
/// @tparam SIGNATURE The signature of the operation to perform. Must be backwards weight.
|
||||
/// @returns RunResult about how the operation completed (or not).
|
||||
///
|
||||
/// @see run()
|
||||
template <auto SIGNATURE>
|
||||
[[nodiscard]] RunResult run(RefConvBwdWeightInstance<SIGNATURE> auto& conv,
|
||||
const Args<SIGNATURE>& args,
|
||||
const Inputs<SIGNATURE>& inputs,
|
||||
const Outputs<SIGNATURE>& outputs)
|
||||
{
|
||||
return detail::run(conv, args, inputs.input, outputs.weight, inputs.output);
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
@@ -1,88 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/testing/conv_fwd.hpp"
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
|
||||
/// This file contains the implementation details for invoking/testing
|
||||
/// grouped convolution operations using the reference implementation.
|
||||
/// The main item is the `run()` function, which is the primary way to
|
||||
/// invoke the reference execution mechanism.
|
||||
/// The implementation of this file mostly looks like `conv_fwd_ck.hpp`,
|
||||
/// but its made specific to the reference implementation, which is
|
||||
/// invoked in a slightly different way.
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
/// @brief Concept for checking whether this is the reference convolution
|
||||
/// implementation.
|
||||
///
|
||||
/// This concept is used to tell whether a convolution implementation is
|
||||
/// likely to be the reference implementation - that is, whether we should
|
||||
/// invoke it like the reference kernel. This is mainly used with `run()` to
|
||||
/// differentiate which implementation that should be invoked.
|
||||
///
|
||||
/// - SIGNATURE is the operation signature.
|
||||
/// - Conv is a convolution instance created by the CK Builder API.
|
||||
template <typename Conv, auto SIGNATURE>
|
||||
concept RefConvInstance = requires(Conv& conv,
|
||||
const void* input,
|
||||
const void* weight,
|
||||
void* output,
|
||||
ck::utils::conv::ConvParam param) {
|
||||
{ conv.Run(input, weight, output, param) };
|
||||
};
|
||||
|
||||
/// @brief `run()` specialization for forward convolution and the reference
|
||||
/// implementation.
|
||||
///
|
||||
/// @tparam SIGNATURE Forward convolution signature.
|
||||
/// @throws std::runtime_error if the arguments weren't actually valid for the
|
||||
/// operation. This should be caught and reported by the testing framework.
|
||||
///
|
||||
/// @return std::tuple<bool, float> - whether the problem is supported and
|
||||
/// kernel execution time (0.0f for reference).
|
||||
/// @see run()
|
||||
template <auto SIGNATURE>
|
||||
requires ValidConvSignature<SIGNATURE> &&
|
||||
// TODO: Maybe we can unify this implementation for bwd/weight too?
|
||||
// for now, just concern outselves with reference and see when the
|
||||
// rest of the bwd/weight plumbing is there.
|
||||
ConvDirectionIsForward<SIGNATURE>
|
||||
std::tuple<bool, float> run(RefConvInstance<SIGNATURE> auto& conv,
|
||||
const Args<SIGNATURE>& args,
|
||||
const Inputs<SIGNATURE>& inputs,
|
||||
const Outputs<SIGNATURE>& outputs)
|
||||
{
|
||||
// We don't want to compute the output dims manually, just get
|
||||
// them via the existing infrastructure
|
||||
const auto param = args.to_ck_conv_param();
|
||||
|
||||
// TODO: The reference convolution is currently missing a few features.
|
||||
// Just throw for now, but regard these as TODO items that should be resolved
|
||||
// eventually.
|
||||
|
||||
if(!args.make_input_descriptor().is_packed())
|
||||
{
|
||||
std::cout << "TODO: Support non-packed input tensor in reference conv" << std::endl;
|
||||
return std::make_tuple(false, 0.0f);
|
||||
}
|
||||
if(!args.make_weight_descriptor().is_packed())
|
||||
{
|
||||
std::cout << "TODO: Support non-packed weight tensor in reference conv" << std::endl;
|
||||
return std::make_tuple(false, 0.0f);
|
||||
}
|
||||
if(!args.make_output_descriptor().is_packed())
|
||||
{
|
||||
std::cout << "TODO: Support non-packed output tensor in reference conv" << std::endl;
|
||||
return std::make_tuple(false, 0.0f);
|
||||
}
|
||||
|
||||
conv.Run(inputs.input, inputs.weight, outputs.output, param);
|
||||
return std::make_tuple(true, 0.0f);
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
@@ -12,6 +12,7 @@
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/testing/type_traits.hpp"
|
||||
#include "ck_tile/builder/testing/tensor_descriptor.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include "ck/utility/data_type.hpp"
|
||||
|
||||
|
||||
@@ -3,7 +3,11 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <optional>
|
||||
#include <concepts>
|
||||
#include <string_view>
|
||||
#include <string>
|
||||
#include <iosfwd>
|
||||
|
||||
#include "ck_tile/builder/testing/tensor_descriptor.hpp"
|
||||
#include "ck_tile/builder/testing/tensor_buffer.hpp"
|
||||
@@ -288,6 +292,57 @@ ValidationReport validate(const Args<SIGNATURE>& args,
|
||||
Outputs<SIGNATURE> actual,
|
||||
Outputs<SIGNATURE> expected) = delete;
|
||||
|
||||
/// @brief This structure represents the result of a run operation.
|
||||
///
|
||||
/// The structure contains multiple fields with information about
|
||||
/// how the operation completed (or not). See those for more info.
|
||||
struct RunResult
|
||||
{
|
||||
/// If this value is not set to `std::nullopt`, there was a problem
|
||||
/// while running the algorithm. In this case, the outputs are not
|
||||
/// valid (though may be partially or completely overwritten), and
|
||||
/// the optional contains a short debug message that indicates the
|
||||
/// problem.
|
||||
std::optional<std::string> error = std::nullopt;
|
||||
|
||||
/// The runtime of the kernel in milliseconds, if measured. Whether the
|
||||
/// runtime is measured at all depends on the stream configuration
|
||||
/// passed to run(). 0 if not measured or if there was an error. This
|
||||
/// value is averaged over the total amount of runs actually done. Again,
|
||||
/// this is usually configured via the stream config.
|
||||
float runtime = 0.f;
|
||||
|
||||
/// @brief Utility function for constructing a RunResult from an unsupported operation.
|
||||
///
|
||||
/// @param msg A short debug message that will be included in the result.
|
||||
constexpr static RunResult not_supported(std::string_view msg)
|
||||
{
|
||||
return RunResult{.error = std::string(msg)};
|
||||
}
|
||||
|
||||
/// @brief Utility function for constructing a RunResult from an average runtime,
|
||||
/// indicating a successful operation.
|
||||
///
|
||||
/// @param runtime The runtime of the kernel in milliseconds.
|
||||
constexpr static RunResult from_runtime(const float runtime)
|
||||
{
|
||||
return RunResult{.runtime = runtime};
|
||||
}
|
||||
|
||||
/// @brief Returns whether this algorithm executed successfully.
|
||||
///
|
||||
/// In this case there should be no message in `error`.
|
||||
bool is_supported() const { return !this->error.has_value(); }
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, const RunResult& result)
|
||||
{
|
||||
if(result.error.has_value())
|
||||
return os << "invalid run (" << result.error.value() << ")";
|
||||
else
|
||||
return os << "successful run (" << result.runtime << " ms)";
|
||||
}
|
||||
|
||||
/// @brief Invoke a device operation created by CK Builder.
|
||||
///
|
||||
/// This is the main function used to invoke a particular device operation
|
||||
@@ -318,13 +373,14 @@ ValidationReport validate(const Args<SIGNATURE>& args,
|
||||
/// @param outputs The output tensor data. The contents will be overwritten by
|
||||
/// this function.
|
||||
/// @param s_conf Stream config used to launch kernel.
|
||||
/// @return std::tuple<bool, float> - whether the problem is supported and
|
||||
/// kernel execution time (0.0f if s_conf time_kernel is false).
|
||||
/// @returns RunResult about how the operation completed (or not).
|
||||
///
|
||||
/// @note This function is explicitly deleted to generate compile errors
|
||||
/// for missing implementations.
|
||||
///
|
||||
/// @see RunResult
|
||||
template <auto SIGNATURE, typename Operation, typename StreamConf>
|
||||
std::tuple<bool, float> run(Operation& operation,
|
||||
[[nodiscard]] RunResult run(Operation& operation,
|
||||
const Args<SIGNATURE>& args,
|
||||
const Inputs<SIGNATURE>& inputs,
|
||||
const Outputs<SIGNATURE>& outputs,
|
||||
|
||||
@@ -5,6 +5,8 @@
|
||||
|
||||
#include <string_view>
|
||||
|
||||
#include "ck_tile/builder/testing/testing.hpp"
|
||||
|
||||
/// testing.hpp requires developers of a type of SIGNATURE to implement
|
||||
/// quite a lot of functionality for each SIGNATURE. For example, next
|
||||
/// to `Args`, `Inputs`, `Outputs`, `run`, they also have to define
|
||||
|
||||
@@ -168,7 +168,7 @@ add_ck_builder_test(test_ckb_build_fwd_instances
|
||||
conv/ck/test_ckb_conv_fwd_3d_fp16.cpp
|
||||
conv/ck/test_ckb_conv_fwd_3d_fp32.cpp
|
||||
conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp
|
||||
)
|
||||
)
|
||||
target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility)
|
||||
|
||||
set(BWD_WEIGHT_TESTS
|
||||
|
||||
@@ -1,23 +1,30 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/builder/testing/conv/bwd_weight.hpp"
|
||||
#include "ck_tile/builder/testing/conv/bwd_weight_ck.hpp"
|
||||
#include "ck_tile/builder/testing/conv/reference.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "testing_utils.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
using ck_tile::test::MatchesReference;
|
||||
using ck_tile::test::SuccessfulRun;
|
||||
|
||||
constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 1,
|
||||
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = ckb::DataType::BF16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = NGCW}},
|
||||
.input = {.config = {.layout = GNWC}},
|
||||
.weight = {.config = {.layout = GKXC}},
|
||||
.output = {.config = {.layout = NGKW}}};
|
||||
.output = {.config = {.layout = GNWK}}};
|
||||
|
||||
constexpr auto ALGORITHM =
|
||||
cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3{}
|
||||
@@ -30,14 +37,58 @@ constexpr auto ALGORITHM =
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
using Reference = ckb::ConvBuilder<SIGNATURE, ckt::ConvAlgorithm_Reference{}>::Instance;
|
||||
|
||||
TEST(BwdWeight_1DBf16_CShuffle_V3, Create)
|
||||
{
|
||||
const auto expected_transfer_parameters = to_string(ALGORITHM);
|
||||
cku::run_test<Builder>({"DeviceGroupedConvBwdWeight_Xdl_CShuffleV3",
|
||||
expected_transfer_parameters,
|
||||
"Filter1x1Stride1Pad0",
|
||||
"NGCW,GKXC,NGKW",
|
||||
"GNWC,GKXC,GNWK",
|
||||
"PassThrough,PassThrough,PassThrough",
|
||||
"Intrawave",
|
||||
"v2"});
|
||||
}
|
||||
|
||||
TEST(BwdWeight_1DBf16_CShuffle_V3, Execution)
|
||||
{
|
||||
if(!ck_tile::get_device_name().starts_with("gfx9"))
|
||||
{
|
||||
// Note: XDL kernel
|
||||
GTEST_SKIP() << "unsupported architecture";
|
||||
}
|
||||
|
||||
ckt::Args<SIGNATURE> args = {
|
||||
.lengths =
|
||||
{
|
||||
.batch_size = 16,
|
||||
.groups = 1,
|
||||
.input_channels = 32,
|
||||
.output_channels = 48,
|
||||
.image = {.width = 64},
|
||||
.filter = {.width = 1},
|
||||
},
|
||||
.filter_strides = {.width = 1},
|
||||
.filter_dilation = {.width = 1},
|
||||
.input_left_pad = {.width = 0},
|
||||
.input_right_pad = {.width = 0},
|
||||
.a_elementwise_op = {},
|
||||
.b_elementwise_op = {},
|
||||
.cde_elementwise_op = {},
|
||||
};
|
||||
|
||||
auto inputs = ckt::alloc_inputs(args);
|
||||
auto outputs = ckt::alloc_outputs(args);
|
||||
auto reference = ckt::alloc_outputs(args);
|
||||
|
||||
ckt::init_inputs(args, inputs.get());
|
||||
|
||||
auto conv = Instance{};
|
||||
EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun());
|
||||
|
||||
auto ref_conv = Reference{};
|
||||
EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun());
|
||||
|
||||
EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get()));
|
||||
}
|
||||
|
||||
@@ -4,8 +4,9 @@
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/builder/testing/conv_fwd_ck.hpp"
|
||||
#include "ck_tile/builder/testing/conv_fwd_reference.hpp"
|
||||
#include "ck_tile/builder/testing/conv/fwd.hpp"
|
||||
#include "ck_tile/builder/testing/conv/fwd_ck.hpp"
|
||||
#include "ck_tile/builder/testing/conv/reference.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "testing_utils.hpp"
|
||||
|
||||
@@ -14,6 +15,7 @@ namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
|
||||
using ck_tile::test::MatchesReference;
|
||||
using ck_tile::test::SuccessfulRun;
|
||||
|
||||
constexpr auto SIGNATURE =
|
||||
ckt::ConvSignature{.spatial_dim = 2,
|
||||
@@ -50,10 +52,11 @@ TEST(Fwd2DFp16_CShufV3_GNHWC, Create)
|
||||
"MNKPadding"});
|
||||
}
|
||||
|
||||
TEST(Fwd2DFp16_CShufV3_GNHWC, EndToEnd)
|
||||
TEST(Fwd2DFp16_CShufV3_GNHWC, Execution)
|
||||
{
|
||||
if(!ck_tile::get_device_name().starts_with("gfx9"))
|
||||
{
|
||||
// Note: XDL kernel
|
||||
GTEST_SKIP() << "unsupported architecture";
|
||||
}
|
||||
|
||||
@@ -91,10 +94,10 @@ TEST(Fwd2DFp16_CShufV3_GNHWC, EndToEnd)
|
||||
ckt::init_inputs(args, inputs.get());
|
||||
|
||||
auto conv = Instance{};
|
||||
ckt::run(conv, args, inputs.get(), outputs.get());
|
||||
EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun());
|
||||
|
||||
auto ref_conv = Reference{};
|
||||
ckt::run(ref_conv, args, inputs.get(), reference.get());
|
||||
EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun());
|
||||
|
||||
EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get()));
|
||||
}
|
||||
|
||||
@@ -1,35 +1,47 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck_tile/builder/testing/conv/bwd_weight.hpp"
|
||||
#include "ck_tile/builder/testing/conv/ck_tile.hpp"
|
||||
#include "ck_tile/builder/testing/conv/reference.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "utils/ckb_conv_tile_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "testing_utils.hpp"
|
||||
|
||||
namespace {
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
using enum ck_tile::builder::TensorLayout;
|
||||
using ck_tile::test::MatchesReference;
|
||||
using ck_tile::test::SuccessfulRun;
|
||||
|
||||
TEST(BwdWeightConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
|
||||
constexpr auto SIGNATURE = cku::ConvSignature{.spatial_dim = 2,
|
||||
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = ckb::DataType::FP16,
|
||||
.accumulation_data_type = ckb::DataType::FP32,
|
||||
.input = {.config = {.layout = NHWGC}},
|
||||
.weight = {.config = {.layout = GKYXC}},
|
||||
.output = {.config = {.layout = NHWGK}}};
|
||||
|
||||
constexpr auto ALGORITHM =
|
||||
cku::ConvAlgorithm_Tile_GroupedConvolutionKernel{}
|
||||
.with_tile_specializations(ckb::TileConvSpecialization::DEFAULT)
|
||||
.with_tile_thread_block(cku::TileThreadBlock_64x64x64)
|
||||
.with_tile_block_gemm(cku::TileBlockGemmDesc_16x16_v3_intrawave)
|
||||
.with_tile_transfer(cku::TileTransfer_4x4x4)
|
||||
.with_tile_optimizations(ckt::TileOptimizations{
|
||||
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
|
||||
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
using Reference = ckb::ConvBuilder<SIGNATURE, ckt::ConvAlgorithm_Reference{}>::Instance;
|
||||
|
||||
TEST(BwdWeight_2D_FP16_NHWGC, Create)
|
||||
{
|
||||
constexpr ConvSignature BwdWeightConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::BACKWARD_WEIGHT,
|
||||
.data_type = DataType::FP16,
|
||||
.accumulation_data_type = DataType::FP32,
|
||||
.input = {.config = {.layout = TensorLayout::NHWGC}},
|
||||
.weight = {.config = {.layout = TensorLayout::GKYXC}},
|
||||
.output = {.config = {.layout = TensorLayout::NHWGK}}};
|
||||
|
||||
constexpr auto BwdWeightConvAlgorithm =
|
||||
ConvAlgorithm_Tile_GroupedConvolutionKernel{}
|
||||
.with_tile_specializations(TileConvSpecialization::DEFAULT)
|
||||
.with_tile_thread_block(TileThreadBlock_64x64x64)
|
||||
.with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
|
||||
.with_tile_transfer(TileTransfer_4x4x4)
|
||||
.with_tile_optimizations(TileOptimizations{
|
||||
.num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
|
||||
|
||||
using Builder = ConvBuilder<BwdWeightConvSignature, BwdWeightConvAlgorithm>;
|
||||
run_ck_tile_test<Builder>({
|
||||
cku::run_ck_tile_test<Builder>({
|
||||
"grouped_convolution_backward_weight",
|
||||
"fp16",
|
||||
"NHWGC_GKYXC_NHWGK",
|
||||
@@ -49,4 +61,38 @@ TEST(BwdWeightConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
TEST(BwdWeight_2D_FP16_NHWGC, Execution)
|
||||
{
|
||||
ckt::Args<SIGNATURE> args = {
|
||||
.lengths =
|
||||
{
|
||||
.batch_size = 2,
|
||||
.groups = 4,
|
||||
.input_channels = 32,
|
||||
.output_channels = 48,
|
||||
.image = {.width = 32, .height = 56},
|
||||
.filter = {.width = 3, .height = 3},
|
||||
},
|
||||
.filter_strides = {.width = 1, .height = 1},
|
||||
.filter_dilation = {.width = 1, .height = 1},
|
||||
.input_left_pad = {.width = 0, .height = 0},
|
||||
.input_right_pad = {.width = 0, .height = 0},
|
||||
.a_elementwise_op = {},
|
||||
.b_elementwise_op = {},
|
||||
.cde_elementwise_op = {},
|
||||
};
|
||||
|
||||
auto inputs = ckt::alloc_inputs(args);
|
||||
auto outputs = ckt::alloc_outputs(args);
|
||||
auto reference = ckt::alloc_outputs(args);
|
||||
|
||||
ckt::init_inputs(args, inputs.get());
|
||||
|
||||
auto conv = Instance{};
|
||||
EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun());
|
||||
|
||||
auto ref_conv = Reference{};
|
||||
EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun());
|
||||
|
||||
EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get()));
|
||||
}
|
||||
|
||||
@@ -4,8 +4,8 @@
|
||||
#include "utils/ckb_conv_tile_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
#include "utils/conv_algorithm_type_utils.hpp"
|
||||
#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp"
|
||||
#include "ck_tile/builder/testing/conv_fwd_reference.hpp"
|
||||
#include "ck_tile/builder/testing/conv/ck_tile.hpp"
|
||||
#include "ck_tile/builder/testing/conv/reference.hpp"
|
||||
#include "ck_tile/host/device_prop.hpp"
|
||||
#include "testing_utils.hpp"
|
||||
|
||||
@@ -13,6 +13,9 @@ namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
namespace cku = ck_tile::builder::test_utils;
|
||||
|
||||
using ck_tile::test::MatchesReference;
|
||||
using ck_tile::test::SuccessfulRun;
|
||||
|
||||
constexpr auto SIGNATURE =
|
||||
ckt::ConvSignature{.spatial_dim = 2,
|
||||
.direction = ckb::ConvDirection::FORWARD,
|
||||
@@ -75,10 +78,10 @@ TEST(Fwd2DFp16_CShufV3_NHWGC, EndToEnd)
|
||||
ckt::init_inputs(args, inputs.get());
|
||||
|
||||
auto conv = Instance{};
|
||||
ckt::run(conv, args, inputs.get(), outputs.get());
|
||||
EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun());
|
||||
|
||||
auto ref_conv = Reference{};
|
||||
ckt::run(ref_conv, args, inputs.get(), reference.get());
|
||||
EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun());
|
||||
|
||||
EXPECT_THAT(outputs.get(), ck_tile::test::MatchesReference(args, reference.get()));
|
||||
EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get()));
|
||||
}
|
||||
|
||||
@@ -5,11 +5,14 @@
|
||||
|
||||
#include "testing_utils.hpp"
|
||||
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
|
||||
using ck_tile::test::HipError;
|
||||
using ck_tile::test::HipSuccess;
|
||||
using ck_tile::test::InstanceMatcher;
|
||||
using ck_tile::test::InstanceSet;
|
||||
using ck_tile::test::StringEqWithDiff;
|
||||
using ck_tile::test::SuccessfulRun;
|
||||
|
||||
TEST(InstanceSet, FromFactory)
|
||||
{
|
||||
@@ -107,3 +110,17 @@ TEST(HipStatusMatcher, Basic)
|
||||
EXPECT_THAT(hipSuccess, Not(HipError(hipErrorInvalidValue)));
|
||||
EXPECT_THAT(hipErrorOutOfMemory, Not(HipError(hipErrorInvalidValue)));
|
||||
}
|
||||
|
||||
TEST(RunResultMatcher, Basic)
|
||||
{
|
||||
EXPECT_THAT(ckt::RunResult::from_runtime(0), SuccessfulRun());
|
||||
EXPECT_THAT(ckt::RunResult::not_supported("test error"), Not(SuccessfulRun()));
|
||||
}
|
||||
|
||||
TEST(RunResultMatcher, ExplainMatchResult)
|
||||
{
|
||||
testing::StringMatchResultListener listener;
|
||||
EXPECT_TRUE(!ExplainMatchResult(
|
||||
SuccessfulRun(), ckt::RunResult::not_supported("test error"), &listener));
|
||||
EXPECT_THAT(listener.str(), StringEqWithDiff("run failed: test error"));
|
||||
}
|
||||
|
||||
@@ -339,4 +339,22 @@ void HipStatusMatcher::DescribeNegationTo(std::ostream* os) const
|
||||
return ::testing::MakeMatcher(new HipStatusMatcher(error));
|
||||
}
|
||||
|
||||
bool RunResultMatcher::MatchAndExplain(builder::test::RunResult actual,
|
||||
::testing::MatchResultListener* listener) const
|
||||
{
|
||||
if(actual.error.has_value() && listener)
|
||||
*listener << "run failed: " << actual.error.value();
|
||||
|
||||
return actual.is_supported();
|
||||
}
|
||||
|
||||
void RunResultMatcher::DescribeTo(std::ostream* os) const { *os << "successful run"; }
|
||||
|
||||
void RunResultMatcher::DescribeNegationTo(std::ostream* os) const { *os << "unsuccessful run"; }
|
||||
|
||||
::testing::Matcher<builder::test::RunResult> SuccessfulRun()
|
||||
{
|
||||
return ::testing::MakeMatcher(new RunResultMatcher());
|
||||
}
|
||||
|
||||
} // namespace ck_tile::test
|
||||
|
||||
@@ -161,6 +161,23 @@ struct HipStatusMatcher : public ::testing::MatcherInterface<hipError_t>
|
||||
/// @param error The error to expect.
|
||||
::testing::Matcher<hipError_t> HipError(hipError_t error);
|
||||
|
||||
/// @brief RunResult matcher
|
||||
///
|
||||
/// `ckt::run` returns a RunResult which indicates whether there was any
|
||||
/// problem while running the algorithm. This matcher is used to match those
|
||||
/// values.
|
||||
struct RunResultMatcher : public ::testing::MatcherInterface<builder::test::RunResult>
|
||||
{
|
||||
bool MatchAndExplain(builder::test::RunResult actual,
|
||||
::testing::MatchResultListener* listener) const override;
|
||||
void DescribeTo(std::ostream* os) const override;
|
||||
void DescribeNegationTo(std::ostream* os) const override;
|
||||
};
|
||||
|
||||
/// @brief Construct a Google Test matcher that checks that a ckt::run result
|
||||
/// was successful.
|
||||
::testing::Matcher<builder::test::RunResult> SuccessfulRun();
|
||||
|
||||
template <auto SIGNATURE>
|
||||
struct ReferenceOutputMatcher
|
||||
: public ::testing::MatcherInterface<builder::test::Outputs<SIGNATURE>>
|
||||
@@ -180,6 +197,21 @@ struct ReferenceOutputMatcher
|
||||
if(listener->IsInterested() && !errors.empty())
|
||||
{
|
||||
*listener << errors.size() << " tensors failed to validate";
|
||||
|
||||
for(const auto& e : errors)
|
||||
{
|
||||
*listener << "\n - " << e.tensor_name << ": ";
|
||||
|
||||
if(e.is_all_zero())
|
||||
*listener << "all elements in actual and expected tensors are zero";
|
||||
else
|
||||
{
|
||||
// Round to 2 digits
|
||||
const float percentage = e.wrong_elements * 10000 / e.total_elements / 100.f;
|
||||
*listener << e.wrong_elements << "/" << e.total_elements
|
||||
<< " incorrect elements (~" << percentage << "%)";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return errors.empty();
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
|
||||
#include "impl/conv_signature_types.hpp"
|
||||
#include "testing_utils.hpp"
|
||||
#include "ck_tile/builder/testing/conv_fwd.hpp"
|
||||
#include "ck_tile/builder/testing/conv/fwd.hpp"
|
||||
#include "ck_tile/builder/testing/tensor_foreach.hpp"
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
|
||||
@@ -296,5 +296,8 @@ TEST(MatchesReference, Incorrect)
|
||||
testing::StringMatchResultListener listener;
|
||||
EXPECT_TRUE(!ExplainMatchResult(MatchesReference(args, expected), actual, &listener));
|
||||
|
||||
EXPECT_THAT(listener.str(), StringEqWithDiff("1 tensors failed to validate"));
|
||||
EXPECT_THAT(listener.str(),
|
||||
StringEqWithDiff( //
|
||||
"1 tensors failed to validate\n"
|
||||
" - a: 625/625 incorrect elements (~100%)"));
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#include "../../builder/test/utils/ckb_conv_tile_test_configs.hpp"
|
||||
#include "ck_tile/builder/testing/conv_fwd_ck_tile.hpp"
|
||||
#include "ck_tile/builder/testing/conv/fwd.hpp"
|
||||
#include "ck_tile/builder/testing/conv/ck_tile.hpp"
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckt = ck_tile::builder::test;
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
|
||||
using Instance = Builder::Instance;
|
||||
|
||||
auto conv = Instance{};
|
||||
bool is_supported;
|
||||
float avg_time;
|
||||
std::tie(is_supported, avg_time) = ckt::run(conv, args, inputs, outputs, s_conf);
|
||||
return std::make_tuple(is_supported, avg_time, conv.GetInstanceString());
|
||||
auto conv = Instance{};
|
||||
ckt::RunResult result = ckt::run(conv, args, inputs, outputs, s_conf);
|
||||
return std::make_tuple(result.is_supported(), result.runtime, conv.GetInstanceString());
|
||||
|
||||
Reference in New Issue
Block a user