mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-21 21:39:15 +00:00
Merge commit 'cd8af997e6d1fde6bc4397bd6ab4fca46510e776' into develop
This commit is contained in:
@@ -58,6 +58,8 @@
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
#include "ck_tile/builder/versions.hpp"
|
||||
|
||||
#include "ck_tile/builder/conv_signature_utils.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory_internal {
|
||||
|
||||
// Type mappings from the builder FwdGroupConvLayout enum classes to the CK tensor data types.
|
||||
@@ -665,7 +667,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
SPATIAL_DIM,
|
||||
ConvDirection::FORWARD>());
|
||||
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using Ops = factory_internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION =
|
||||
@@ -762,7 +764,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
SPATIAL_DIM,
|
||||
ConvDirection::FORWARD>());
|
||||
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using Ops = factory_internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION =
|
||||
@@ -858,7 +860,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
SPATIAL_DIM,
|
||||
ConvDirection::FORWARD>());
|
||||
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using Ops = factory_internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION =
|
||||
@@ -980,7 +982,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
SPATIAL_DIM,
|
||||
ConvDirection::FORWARD>());
|
||||
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using Ops = factory_internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static constexpr auto BASE_ALGORITHM = ALGORITHM.base_algorithm;
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
/**********************************************
|
||||
* constexpr helper functions for optional parameters
|
||||
**********************************************/
|
||||
|
||||
template <auto Sig>
|
||||
concept ProvidesElementwiseOperation = requires { Sig.elementwiseOperation; };
|
||||
|
||||
template <auto Sig>
|
||||
concept ProvidesConvolutionDirection = requires { Sig.direction; };
|
||||
|
||||
template <auto Sig>
|
||||
constexpr auto get_elementwise_operation()
|
||||
{
|
||||
if constexpr(ProvidesElementwiseOperation<Sig>)
|
||||
{
|
||||
return Sig.elementwise_operation;
|
||||
}
|
||||
else
|
||||
{
|
||||
return ElementwiseOperation::PASS_THROUGH;
|
||||
}
|
||||
}
|
||||
|
||||
template <auto Sig>
|
||||
constexpr auto get_conv_direction()
|
||||
{
|
||||
if constexpr(ProvidesConvolutionDirection<Sig>)
|
||||
{
|
||||
return Sig.direction;
|
||||
}
|
||||
else
|
||||
{
|
||||
return ConvDirection::FORWARD;
|
||||
}
|
||||
}
|
||||
} // namespace ck_tile::builder
|
||||
Reference in New Issue
Block a user