[CK TILE] Enable CK TILE Conv Fwd tests in CI and fix check_err (#3624)

* [CK TILE] Enable CK TILE Conv Fwd tests in CI and fix check_err

* Update test_grouped_convnd_fwd_tile.cpp

* Update test_grouped_convnd_fwd_tile.cpp

* Update conv_tuning_params.hpp

* clang format fix

* Update CMakeLists.txt
This commit is contained in:
Bartłomiej Kocot
2026-01-27 10:04:11 +01:00
committed by GitHub
parent c190d8d61f
commit 3d67e6c492
14 changed files with 114 additions and 46 deletions

View File

@@ -58,6 +58,7 @@ consteval BlockGemmSpec SetBlockGemm()
case PipelineVersion::V3: version = ck::BlockGemmPipelineVersion::v3; break;
case PipelineVersion::V4: version = ck::BlockGemmPipelineVersion::v4; break;
case PipelineVersion::V5: version = ck::BlockGemmPipelineVersion::v5; break;
case PipelineVersion::V6: throw "PipelineVersion::V6 is supported only for CK Tile.";
case PipelineVersion::WEIGHT_ONLY:
throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM.";
default: throw "Unknown PipelineVersion";
@@ -92,6 +93,7 @@ consteval ck::PipelineVersion SetGridwiseGemmPipelineVersion()
case PipelineVersion::V3: throw "PipelineVersion::V3 is used only for stream-K.";
case PipelineVersion::V4: return ck_pipeline::v4;
case PipelineVersion::V5: throw "PipelineVersion::V5 cannot be used for gridwise GEMM.";
case PipelineVersion::V6: throw "PipelineVersion::V6 can be used only for CK TILE.";
case PipelineVersion::WEIGHT_ONLY: return ck_pipeline::weight_only;
default: throw "Unknown GridwiseGemmPipelineVersion";
}
@@ -137,6 +139,7 @@ consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
case PipelineVersion::V3: return ck_pipeline::v3;
case PipelineVersion::V4: return ck_pipeline::v4;
case PipelineVersion::V5: return ck_pipeline::v5;
case PipelineVersion::V6: throw "PipelineVersion::V6 is supported only for CK Tile.";
case PipelineVersion::WEIGHT_ONLY:
throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version.";
default: throw "Unknown block GEMM PipelineVersion";

View File

@@ -91,6 +91,13 @@ struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V5>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
};
template <>
struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V6>
{
template <typename PipelineProblem>
using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV6<PipelineProblem>;
};
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion()
{
@@ -103,6 +110,7 @@ consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion()
case PipelineVersion::V3: return ck_tile_pipeline::COMPUTE_V3;
case PipelineVersion::V4: return ck_tile_pipeline::COMPUTE_V4;
case PipelineVersion::V5: return ck_tile_pipeline::COMPUTE_V5;
case PipelineVersion::V6: return ck_tile_pipeline::COMPUTE_V6;
case PipelineVersion::WEIGHT_ONLY:
throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version.";
default: throw "Unknown block GEMM PipelineVersion";

View File

@@ -51,6 +51,9 @@ struct ValidationReport
/// The number of elements which were bitwise 0.
uint64_t zero_elements;
// Max error.
double max_error;
/// @brief Check whether both the output and reference tensor were both all zeros.
///
/// If both tensors are all zero, it indicates either an incorrect testing setup
@@ -133,11 +136,12 @@ bool ValidationReport::check(std::string_view tensor_name,
// Initial pass: count errors
// Allocate and reset counter
auto d_counters = alloc_buffer(sizeof(uint64_t) * 2);
check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 2));
auto d_counters = alloc_buffer(sizeof(uint64_t) * 3);
check_hip(hipMemset(d_counters.get(), 0, sizeof(uint64_t) * 3));
auto d_error_count = &reinterpret_cast<uint64_t*>(d_counters.get())[0];
auto d_zero_count = &reinterpret_cast<uint64_t*>(d_counters.get())[1];
auto d_max_error = &reinterpret_cast<double*>(d_counters.get())[2];
tensor_foreach(descriptor.get_lengths(), [=](auto index) {
using CKType = typename factory::internal::DataTypeToCK<DT>::type;
@@ -157,6 +161,7 @@ bool ValidationReport::check(std::string_view tensor_name,
const auto r = static_cast<double>(type_convert<float>(b));
const auto err = std::abs(o - r);
atomicMax(d_max_error, err);
if(err > atol + rtol * std::abs(r) || !std::isfinite(o) || !std::isfinite(r))
{
// We expect the number of errors to be very low, so just use an atomic
@@ -188,6 +193,8 @@ bool ValidationReport::check(std::string_view tensor_name,
check_hip(hipMemcpy(&error_count, d_error_count, sizeof(uint64_t), hipMemcpyDeviceToHost));
uint64_t zero_count = 0;
check_hip(hipMemcpy(&zero_count, d_zero_count, sizeof(uint64_t), hipMemcpyDeviceToHost));
double max_error = 0;
check_hip(hipMemcpy(&max_error, d_max_error, sizeof(double), hipMemcpyDeviceToHost));
// TODO: Gather detailed coordinates.
@@ -196,6 +203,7 @@ bool ValidationReport::check(std::string_view tensor_name,
.wrong_elements = error_count,
.total_elements = descriptor.get_element_size(),
.zero_elements = zero_count,
.max_error = max_error,
});
return reports_.back().is_ok();

View File

@@ -157,6 +157,7 @@ enum class PipelineVersion
V3,
V4,
V5,
V6,
WEIGHT_ONLY
};
@@ -328,6 +329,7 @@ inline std::string_view to_string(PipelineVersion ver)
case V3: return "V3";
case V4: return "V4";
case V5: return "V5";
case V6: return "V6";
case WEIGHT_ONLY: return "WEIGHT_ONLY";
default: return "Unknown";
}