[CK_BUILDER] Clean-up fwd conv builder implementation (#3110)

This commit is contained in:
Ville Pietilä
2025-10-29 05:37:33 +02:00
committed by GitHub
parent 515e283091
commit 13e13ce359
10 changed files with 84 additions and 48 deletions

View File

@@ -218,6 +218,14 @@ struct ElementwiseOps<ElementwiseOperation::PASS_THROUGH>
using CDEElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
};
template <>
struct ElementwiseOps<ElementwiseOperation::SCALE>
{
using AElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
using BElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementwiseOp = ck::tensor_operation::element_wise::Scale;
};
// The algorithm specializations for the convolution and GEMM.
template <typename CONV_ENUM>
requires(
@@ -365,6 +373,10 @@ consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
{
return ck::BlockGemmPipelineVersion::v1;
}
else if constexpr(version == BlockGemmPipelineVersion::V2)
{
return ck::BlockGemmPipelineVersion::v2;
}
else if constexpr(version == BlockGemmPipelineVersion::V3)
{
return ck::BlockGemmPipelineVersion::v3;
@@ -434,9 +446,6 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
using AlgorithmType = decltype(ALGORITHM);
// Check preconditions for the algorithm description.
static_assert(SPATIAL_DIM == 2 || SPATIAL_DIM == 3,
"Only 2D and 3D convolutions are supported in this factory.");
static_assert(SpecifiesThreadBlock<AlgorithmType>,
"The convolution algorithm descriptor must specify thread block info.");
static_assert(SpecifiesGridwiseGemm<AlgorithmType>,