[CK_Builder ]fixed accidental drop of get_elementwise_operation during merge and added usage of get_elementwise_operation() to other builder instances (#3238)

Fixed issues encountered during merge of #3192

* fixed accidental drop of get_elementwise_operation during merge and added call to get_elementwise_op to 4 other builders

* run clang-format

---------

Co-authored-by: Kevin Abraham <kevin.abraham@streamhpc.com>
This commit is contained in:
kabrahamAMD
2025-11-19 21:31:05 +01:00
committed by GitHub
parent e6e2e04edb
commit 964f8e1f60
3 changed files with 54 additions and 4 deletions

View File

@@ -58,6 +58,8 @@
#include "ck_tile/builder/types.hpp"
#include "ck_tile/builder/versions.hpp"
#include "ck_tile/builder/conv_signature_utils.hpp"
namespace ck_tile::builder::factory_internal {
// Type mappings from the builder FwdGroupConvLayout enum classes to the CK tensor data types.
@@ -665,7 +667,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
SPATIAL_DIM,
ConvDirection::FORWARD>());
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
using Ops = factory_internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto FWD_CONV_SPECIALIZATION =
@@ -762,7 +764,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
SPATIAL_DIM,
ConvDirection::FORWARD>());
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
using Ops = factory_internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto FWD_CONV_SPECIALIZATION =
@@ -858,7 +860,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
SPATIAL_DIM,
ConvDirection::FORWARD>());
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
using Ops = factory_internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto FWD_CONV_SPECIALIZATION =
@@ -980,7 +982,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
SPATIAL_DIM,
ConvDirection::FORWARD>());
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
using Ops = factory_internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto BASE_ALGORITHM = ALGORITHM.base_algorithm;

View File

@@ -0,0 +1,47 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include <type_traits>
#include "ck_tile/builder/types.hpp"
namespace ck_tile::builder {
/**********************************************
* constexpr helper functions for optional parameters
**********************************************/
template <auto Sig>
concept ProvidesElementwiseOperation = requires { Sig.elementwiseOperation; };
template <auto Sig>
concept ProvidesConvolutionDirection = requires { Sig.direction; };
template <auto Sig>
constexpr auto get_elementwise_operation()
{
if constexpr(ProvidesElementwiseOperation<Sig>)
{
return Sig.elementwise_operation;
}
else
{
return ElementwiseOperation::PASS_THROUGH;
}
}
template <auto Sig>
constexpr auto get_conv_direction()
{
if constexpr(ProvidesConvolutionDirection<Sig>)
{
return Sig.direction;
}
else
{
return ConvDirection::FORWARD;
}
}
} // namespace ck_tile::builder

View File

@@ -49,6 +49,7 @@ struct ConvSignatureWithInvalidOptionalParams
ckb::GroupConvDeviceOp device_operation =
ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
};
static_assert(!ckb::ConvSignatureDescriptor<ConvSignatureWithInvalidOptionalParams>);
struct DefaultAlgorithm