mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_Builder ]fixed accidental drop of get_elementwise_operation during merge and added usage of get_elementwise_operation() to other builder instances (#3238)
Fixed issues encountered during merge of #3192 * fixed accidental drop of get_elementwise_operation during merge and added call to get_elementwise_op to 4 other builders * run clang-format --------- Co-authored-by: Kevin Abraham <kevin.abraham@streamhpc.com>
This commit is contained in:
@@ -58,6 +58,8 @@
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
#include "ck_tile/builder/versions.hpp"
|
||||
|
||||
#include "ck_tile/builder/conv_signature_utils.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory_internal {
|
||||
|
||||
// Type mappings from the builder FwdGroupConvLayout enum classes to the CK tensor data types.
|
||||
@@ -665,7 +667,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
SPATIAL_DIM,
|
||||
ConvDirection::FORWARD>());
|
||||
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using Ops = factory_internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION =
|
||||
@@ -762,7 +764,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
SPATIAL_DIM,
|
||||
ConvDirection::FORWARD>());
|
||||
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using Ops = factory_internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION =
|
||||
@@ -858,7 +860,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
SPATIAL_DIM,
|
||||
ConvDirection::FORWARD>());
|
||||
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using Ops = factory_internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION =
|
||||
@@ -980,7 +982,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
SPATIAL_DIM,
|
||||
ConvDirection::FORWARD>());
|
||||
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using Ops = factory_internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BASE_ALGORITHM = ALGORITHM.base_algorithm;
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
/**********************************************
|
||||
* constexpr helper functions for optional parameters
|
||||
**********************************************/
|
||||
|
||||
template <auto Sig>
|
||||
concept ProvidesElementwiseOperation = requires { Sig.elementwiseOperation; };
|
||||
|
||||
template <auto Sig>
|
||||
concept ProvidesConvolutionDirection = requires { Sig.direction; };
|
||||
|
||||
template <auto Sig>
|
||||
constexpr auto get_elementwise_operation()
|
||||
{
|
||||
if constexpr(ProvidesElementwiseOperation<Sig>)
|
||||
{
|
||||
return Sig.elementwise_operation;
|
||||
}
|
||||
else
|
||||
{
|
||||
return ElementwiseOperation::PASS_THROUGH;
|
||||
}
|
||||
}
|
||||
|
||||
template <auto Sig>
|
||||
constexpr auto get_conv_direction()
|
||||
{
|
||||
if constexpr(ProvidesConvolutionDirection<Sig>)
|
||||
{
|
||||
return Sig.direction;
|
||||
}
|
||||
else
|
||||
{
|
||||
return ConvDirection::FORWARD;
|
||||
}
|
||||
}
|
||||
} // namespace ck_tile::builder
|
||||
@@ -49,6 +49,7 @@ struct ConvSignatureWithInvalidOptionalParams
|
||||
ckb::GroupConvDeviceOp device_operation =
|
||||
ckb::FwdGroupConvDeviceOperation::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3;
|
||||
};
|
||||
|
||||
static_assert(!ckb::ConvSignatureDescriptor<ConvSignatureWithInvalidOptionalParams>);
|
||||
|
||||
struct DefaultAlgorithm
|
||||
|
||||
Reference in New Issue
Block a user