mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 20:51:23 +00:00
[CK_BUILDER] conv bwd weight testing (#3618)
* ck-builder: restructure testing conv In order to prepare for bwd of conv testing, this commit moves some files and types around so that we can reuse ckt::Args for both forward and backwards convolution. * ck-builder: decouple fwd_ck.hpp and fwd_reference.hpp from fwd.hpp This will allow us to more easily include fwd.hpp from backwards definitions, which is required for initializing bwd values. * ck-builder: fix layout of test_ckb_conv_bwd_weight_xdl_cshuffle_v3 Turns out that the supplied layout isn't actually supported... * ck-builder: ck and reference conv integration for bwd weight * ck-builder: ck bwd weight execution test * ck-builder: ckt::run support for ck-tile bwd weight * ck-builder: ck tile bwd weight execution test * ck-builder: extra debug printing in MatchesReference * ck-builder: make ckt::run return RunResult This type is more convenient than std::tuple, as it will allow us to use google test matchers with this in the future. * ck-builder: RunResult matcher Using EXPECT_THAT(..., SuccessfulRun()) will generate a check and a nice error message about how and why running an algorithm failed. * ck-builder: doc fixes * ck-builder: add missing headers
This commit is contained in:
@@ -3,7 +3,11 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <optional>
|
||||
#include <concepts>
|
||||
#include <string_view>
|
||||
#include <string>
|
||||
#include <iosfwd>
|
||||
|
||||
#include "ck_tile/builder/testing/tensor_descriptor.hpp"
|
||||
#include "ck_tile/builder/testing/tensor_buffer.hpp"
|
||||
@@ -288,6 +292,57 @@ ValidationReport validate(const Args<SIGNATURE>& args,
|
||||
Outputs<SIGNATURE> actual,
|
||||
Outputs<SIGNATURE> expected) = delete;
|
||||
|
||||
/// @brief This structure represents the result of a run operation.
|
||||
///
|
||||
/// The structure contains multiple fields with information about
|
||||
/// how the operation completed (or not). See those for more info.
|
||||
struct RunResult
|
||||
{
|
||||
/// If this value is not set to `std::nullopt`, there was a problem
|
||||
/// while running the algorithm. In this case, the outputs are not
|
||||
/// valid (though may be partially or completely overwritten), and
|
||||
/// the optional contains a short debug message that indicates the
|
||||
/// problem.
|
||||
std::optional<std::string> error = std::nullopt;
|
||||
|
||||
/// The runtime of the kernel in milliseconds, if measured. Whether the
|
||||
/// runtime is measured at all depends on the stream configuration
|
||||
/// passed to run(). 0 if not measured or if there was an error. This
|
||||
/// value is averaged over the total amount of runs actually done. Again,
|
||||
/// this is usually configured via the stream config.
|
||||
float runtime = 0.f;
|
||||
|
||||
/// @brief Utility function for constructing a RunResult from an unsupported operation.
|
||||
///
|
||||
/// @param msg A short debug message that will be included in the result.
|
||||
constexpr static RunResult not_supported(std::string_view msg)
|
||||
{
|
||||
return RunResult{.error = std::string(msg)};
|
||||
}
|
||||
|
||||
/// @brief Utility function for constructing a RunResult from an average runtime,
|
||||
/// indicating a successful operation.
|
||||
///
|
||||
/// @param runtime The runtime of the kernel in milliseconds.
|
||||
constexpr static RunResult from_runtime(const float runtime)
|
||||
{
|
||||
return RunResult{.runtime = runtime};
|
||||
}
|
||||
|
||||
/// @brief Returns whether this algorithm executed successfully.
|
||||
///
|
||||
/// In this case there should be no message in `error`.
|
||||
bool is_supported() const { return !this->error.has_value(); }
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, const RunResult& result)
|
||||
{
|
||||
if(result.error.has_value())
|
||||
return os << "invalid run (" << result.error.value() << ")";
|
||||
else
|
||||
return os << "successful run (" << result.runtime << " ms)";
|
||||
}
|
||||
|
||||
/// @brief Invoke a device operation created by CK Builder.
|
||||
///
|
||||
/// This is the main function used to invoke a particular device operation
|
||||
@@ -318,13 +373,14 @@ ValidationReport validate(const Args<SIGNATURE>& args,
|
||||
/// @param outputs The output tensor data. The contents will be overwritten by
|
||||
/// this function.
|
||||
/// @param s_conf Stream config used to launch kernel.
|
||||
/// @return std::tuple<bool, float> - whether the problem is supported and
|
||||
/// kernel execution time (0.0f if s_conf time_kernel is false).
|
||||
/// @returns RunResult about how the operation completed (or not).
|
||||
///
|
||||
/// @note This function is explicitly deleted to generate compile errors
|
||||
/// for missing implementations.
|
||||
///
|
||||
/// @see RunResult
|
||||
template <auto SIGNATURE, typename Operation, typename StreamConf>
|
||||
std::tuple<bool, float> run(Operation& operation,
|
||||
[[nodiscard]] RunResult run(Operation& operation,
|
||||
const Args<SIGNATURE>& args,
|
||||
const Inputs<SIGNATURE>& inputs,
|
||||
const Outputs<SIGNATURE>& outputs,
|
||||
|
||||
Reference in New Issue
Block a user