[CK_BUILDER] conv bwd weight testing (#3618)

* ck-builder: restructure testing conv

To prepare for testing backward convolution, this commit moves some
files and types around so that ckt::Args can be reused for both forward
and backward convolution.

* ck-builder: decouple fwd_ck.hpp and fwd_reference.hpp from fwd.hpp

This makes it easier to include fwd.hpp from the backward definitions,
which is required for initializing bwd values.

* ck-builder: fix layout of test_ckb_conv_bwd_weight_xdl_cshuffle_v3

Turns out that the supplied layout isn't actually supported...

* ck-builder: ck and reference conv integration for bwd weight

* ck-builder: ck bwd weight execution test

* ck-builder: ckt::run support for ck-tile bwd weight

* ck-builder: ck tile bwd weight execution test

* ck-builder: extra debug printing in MatchesReference

* ck-builder: make ckt::run return RunResult

This type is more convenient than std::tuple, as it will allow us to
use GoogleTest matchers with it in the future.

* ck-builder: RunResult matcher

Using EXPECT_THAT(..., SuccessfulRun()) will generate a check and a nice error
message about how and why running an algorithm failed.

* ck-builder: doc fixes

* ck-builder: add missing headers
Robin Voetter
2026-01-26 23:50:15 +01:00
committed by GitHub
parent 8654c0628f
commit cc75948d1c
27 changed files with 939 additions and 262 deletions


@@ -3,7 +3,11 @@
#pragma once
#include <optional>
#include <concepts>
#include <string_view>
#include <string>
#include <iosfwd>
#include "ck_tile/builder/testing/tensor_descriptor.hpp"
#include "ck_tile/builder/testing/tensor_buffer.hpp"
@@ -288,6 +292,57 @@ ValidationReport validate(const Args<SIGNATURE>& args,
Outputs<SIGNATURE> actual,
Outputs<SIGNATURE> expected) = delete;
/// @brief This structure represents the result of a run operation.
///
/// The structure contains multiple fields with information about
/// how the operation completed (or not). See those for more info.
struct RunResult
{
    /// If this value is not set to `std::nullopt`, there was a problem
    /// while running the algorithm. In this case, the outputs are not
    /// valid (though they may be partially or completely overwritten), and
    /// the optional contains a short debug message that indicates the
    /// problem.
    std::optional<std::string> error = std::nullopt;

    /// The runtime of the kernel in milliseconds, if measured. Whether the
    /// runtime is measured at all depends on the stream configuration
    /// passed to run(). 0 if not measured or if there was an error. This
    /// value is averaged over the number of runs actually performed, which
    /// is also usually configured via the stream config.
    float runtime = 0.f;

    /// @brief Utility function for constructing a RunResult from an unsupported operation.
    ///
    /// @param msg A short debug message that will be included in the result.
    constexpr static RunResult not_supported(std::string_view msg)
    {
        return RunResult{.error = std::string(msg)};
    }

    /// @brief Utility function for constructing a RunResult from an average runtime,
    /// indicating a successful operation.
    ///
    /// @param runtime The runtime of the kernel in milliseconds.
    constexpr static RunResult from_runtime(const float runtime)
    {
        return RunResult{.runtime = runtime};
    }

    /// @brief Returns whether this algorithm executed successfully.
    ///
    /// In this case there should be no message in `error`.
    bool is_supported() const { return !this->error.has_value(); }
};

inline std::ostream& operator<<(std::ostream& os, const RunResult& result)
{
    if(result.error.has_value())
        return os << "invalid run (" << result.error.value() << ")";
    else
        return os << "successful run (" << result.runtime << " ms)";
}

/// @brief Invoke a device operation created by CK Builder.
///
/// This is the main function used to invoke a particular device operation
@@ -318,13 +373,14 @@ ValidationReport validate(const Args<SIGNATURE>& args,
/// @param outputs The output tensor data. The contents will be overwritten by
/// this function.
/// @param s_conf Stream config used to launch kernel.
-/// @return std::tuple<bool, float> - whether the problem is supported and
-/// kernel execution time (0.0f if s_conf time_kernel is false).
+/// @returns RunResult about how the operation completed (or not).
///
/// @note This function is explicitly deleted to generate compile errors
/// for missing implementations.
///
/// @see RunResult
template <auto SIGNATURE, typename Operation, typename StreamConf>
-std::tuple<bool, float> run(Operation& operation,
+[[nodiscard]] RunResult run(Operation& operation,
const Args<SIGNATURE>& args,
const Inputs<SIGNATURE>& inputs,
const Outputs<SIGNATURE>& outputs,
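
As a rough illustration of how a caller might consume the `RunResult` that `run()` now returns, the sketch below copies the struct and its stream operator from the hunk above in simplified form. It drops `constexpr` and the designated initializers so that it also compiles as C++17. `fake_run` and `describe` are hypothetical stand-ins introduced only for this example; the real `run()` takes an operation, arguments, inputs, outputs, and a stream config.

```cpp
#include <optional>
#include <ostream>
#include <sstream>
#include <string>
#include <string_view>

// Simplified copy of the RunResult struct from the diff above.
struct RunResult
{
    std::optional<std::string> error = std::nullopt;
    float runtime                    = 0.f;

    static RunResult not_supported(std::string_view msg)
    {
        RunResult result;
        result.error = std::string(msg);
        return result;
    }

    static RunResult from_runtime(const float runtime)
    {
        RunResult result;
        result.runtime = runtime;
        return result;
    }

    bool is_supported() const { return !error.has_value(); }
};

// Same formatting as the operator<< in the diff above.
inline std::ostream& operator<<(std::ostream& os, const RunResult& result)
{
    if(result.error.has_value())
        return os << "invalid run (" << result.error.value() << ")";
    return os << "successful run (" << result.runtime << " ms)";
}

// Hypothetical stand-in for run(): reports either a measured runtime or
// the reason why the operation could not be executed.
[[nodiscard]] RunResult fake_run(bool supported)
{
    if(!supported)
        return RunResult::not_supported("layout not supported");
    return RunResult::from_runtime(1.5f);
}

// Hypothetical helper: formats a RunResult into a string via operator<<.
std::string describe(const RunResult& result)
{
    std::ostringstream ss;
    ss << result;
    return ss.str();
}
```

With these definitions, `describe(fake_run(true))` yields `successful run (1.5 ms)` and `describe(fake_run(false))` yields `invalid run (layout not supported)`. Compared with a `std::tuple<bool, float>`, the named fields and the attached message let the caller see why a run was rejected.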