mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
[CK_BUILDER] Clean-up fwd conv builder implementation (#3110)
This commit is contained in:
@@ -218,6 +218,14 @@ struct ElementwiseOps<ElementwiseOperation::PASS_THROUGH>
|
||||
using CDEElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ElementwiseOps<ElementwiseOperation::SCALE>
|
||||
{
|
||||
using AElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CDEElementwiseOp = ck::tensor_operation::element_wise::Scale;
|
||||
};
|
||||
|
||||
// The algorithm specializations for the convolution and GEMM.
|
||||
template <typename CONV_ENUM>
|
||||
requires(
|
||||
@@ -365,6 +373,10 @@ consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v1;
|
||||
}
|
||||
else if constexpr(version == BlockGemmPipelineVersion::V2)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v2;
|
||||
}
|
||||
else if constexpr(version == BlockGemmPipelineVersion::V3)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v3;
|
||||
@@ -434,9 +446,6 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
// Check preconditions for the algorithm description.
|
||||
static_assert(SPATIAL_DIM == 2 || SPATIAL_DIM == 3,
|
||||
"Only 2D and 3D convolutions are supported in this factory.");
|
||||
static_assert(SpecifiesThreadBlock<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread block info.");
|
||||
static_assert(SpecifiesGridwiseGemm<AlgorithmType>,
|
||||
|
||||
Reference in New Issue
Block a user