mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
[CK TILE] Enable CK TILE Conv Fwd tests in CI and fix check_err (#3624)
* [CK TILE] Enable CK TILE Conv Fwd tests in CI and fix check_err * Update test_grouped_convnd_fwd_tile.cpp * Update test_grouped_convnd_fwd_tile.cpp * Update conv_tuning_params.hpp * clang format fix * Update CMakeLists.txt
This commit is contained in:
@@ -58,6 +58,7 @@ consteval BlockGemmSpec SetBlockGemm()
|
||||
case PipelineVersion::V3: version = ck::BlockGemmPipelineVersion::v3; break;
|
||||
case PipelineVersion::V4: version = ck::BlockGemmPipelineVersion::v4; break;
|
||||
case PipelineVersion::V5: version = ck::BlockGemmPipelineVersion::v5; break;
|
||||
case PipelineVersion::V6: throw "PipelineVersion::V6 is supported only for CK Tile.";
|
||||
case PipelineVersion::WEIGHT_ONLY:
|
||||
throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM.";
|
||||
default: throw "Unknown PipelineVersion";
|
||||
@@ -92,6 +93,7 @@ consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion()
|
||||
case PipelineVersion::V3: throw "PipelineVersion::V3 is used only for stream-K.";
|
||||
case PipelineVersion::V4: return ck_pipeline::v4;
|
||||
case PipelineVersion::V5: throw "PipelineVersion::V5 cannot be used for gridwise GEMM.";
|
||||
case PipelineVersion::V6: throw "PipelineVersion::V6 can be used only for CK TILE.";
|
||||
case PipelineVersion::WEIGHT_ONLY: return ck_pipeline::weight_only;
|
||||
default: throw "Unknown GridwiseGemmPipelineVersion";
|
||||
}
|
||||
@@ -137,6 +139,7 @@ consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
|
||||
case PipelineVersion::V3: return ck_pipeline::v3;
|
||||
case PipelineVersion::V4: return ck_pipeline::v4;
|
||||
case PipelineVersion::V5: return ck_pipeline::v5;
|
||||
case PipelineVersion::V6: throw "PipelineVersion::V6 is supported only for CK Tile.";
|
||||
case PipelineVersion::WEIGHT_ONLY:
|
||||
throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version.";
|
||||
default: throw "Unknown block GEMM PipelineVersion";
|
||||
|
||||
@@ -91,6 +91,13 @@ struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V5>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V6>
|
||||
{
|
||||
template <typename PipelineProblem>
|
||||
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV6<PipelineProblem>;
|
||||
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion()
|
||||
{
|
||||
@@ -103,6 +110,7 @@ consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion()
|
||||
case PipelineVersion::V3: return ck_tile_pipeline::COMPUTE_V3;
|
||||
case PipelineVersion::V4: return ck_tile_pipeline::COMPUTE_V4;
|
||||
case PipelineVersion::V5: return ck_tile_pipeline::COMPUTE_V5;
|
||||
case PipelineVersion::V6: return ck_tile_pipeline::COMPUTE_V6;
|
||||
case PipelineVersion::WEIGHT_ONLY:
|
||||
throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version.";
|
||||
default: throw "Unknown block GEMM PipelineVersion";
|
||||
|
||||
@@ -51,6 +51,9 @@ struct ValidationReport
|
||||
/// The number of elements which were bitwise 0.
|
||||
uint64_t zero_elements;
|
||||
|
||||
// Max error.
|
||||
double max_error;
|
||||
|
||||
/// @brief Check whether both the output and reference tensor were both all zeros.
|
||||
///
|
||||
/// If both tensors are all zero, it indicates either an incorrect testing setup
|
||||
@@ -133,11 +136,12 @@ bool ValidationReport::check(std::string_view tensor_name,
|
||||
// Initial pass: count errors
|
||||
|
||||
// Allocate and reset counter
|
||||
auto d_counters = alloc_buffer(sizeof(uint64_t) * 2);
|
||||
check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 2));
|
||||
auto d_counters = alloc_buffer(sizeof(uint64_t) * 3);
|
||||
check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 3));
|
||||
|
||||
auto d_error_count = &reinterpret_cast<uint64_t*>(d_counters.get())[0];
|
||||
auto d_zero_count = &reinterpret_cast<uint64_t*>(d_counters.get())[1];
|
||||
auto d_max_error = &reinterpret_cast<double*>(d_counters.get())[2];
|
||||
|
||||
tensor_foreach(descriptor.get_lengths(), [=](auto index) {
|
||||
using CKType = typename factory::internal::DataTypeToCK<DT>::type;
|
||||
@@ -157,6 +161,7 @@ bool ValidationReport::check(std::string_view tensor_name,
|
||||
const auto r = static_cast<double>(type_convert<float>(b));
|
||||
const auto err = std::abs(o - r);
|
||||
|
||||
atomicMax(d_max_error, err);
|
||||
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
|
||||
{
|
||||
// We expect the number of errors to be very low, so just use an atomic
|
||||
@@ -188,6 +193,8 @@ bool ValidationReport::check(std::string_view tensor_name,
|
||||
check_hip(hipMemcpy(&error_count, d_error_count, sizeof(uint64_t), hipMemcpyDeviceToHost));
|
||||
uint64_t zero_count = 0;
|
||||
check_hip(hipMemcpy(&zero_count, d_zero_count, sizeof(uint64_t), hipMemcpyDeviceToHost));
|
||||
double max_error = 0;
|
||||
check_hip(hipMemcpy(&max_error, d_max_error, sizeof(double), hipMemcpyDeviceToHost));
|
||||
|
||||
// TODO: Gather detailed coordinates.
|
||||
|
||||
@@ -196,6 +203,7 @@ bool ValidationReport::check(std::string_view tensor_name,
|
||||
.wrong_elements = error_count,
|
||||
.total_elements = descriptor.get_element_size(),
|
||||
.zero_elements = zero_count,
|
||||
.max_error = max_error,
|
||||
});
|
||||
|
||||
return reports_.back().is_ok();
|
||||
|
||||
@@ -157,6 +157,7 @@ enum class PipelineVersion
|
||||
V3,
|
||||
V4,
|
||||
V5,
|
||||
V6,
|
||||
WEIGHT_ONLY
|
||||
};
|
||||
|
||||
@@ -328,6 +329,7 @@ inline std::string_view to_string(PipelineVersion ver)
|
||||
case V3: return "V3";
|
||||
case V4: return "V4";
|
||||
case V5: return "V5";
|
||||
case V6: return "V6";
|
||||
case WEIGHT_ONLY: return "WEIGHT_ONLY";
|
||||
default: return "Unknown";
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user