Merge commit 'cd8af997e6d1fde6bc4397bd6ab4fca46510e776' into develop

This commit is contained in:
assistant-librarian[bot]
2025-11-19 21:11:39 +00:00
parent 7ed276c492
commit ca48bf3b98
8 changed files with 375 additions and 4 deletions

View File

@@ -58,6 +58,8 @@
#include "ck_tile/builder/types.hpp"
#include "ck_tile/builder/versions.hpp"
#include "ck_tile/builder/conv_signature_utils.hpp"
namespace ck_tile::builder::factory_internal {
// Type mappings from the builder FwdGroupConvLayout enum classes to the CK tensor data types.
@@ -665,7 +667,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
SPATIAL_DIM,
ConvDirection::FORWARD>());
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
using Ops = factory_internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto FWD_CONV_SPECIALIZATION =
@@ -762,7 +764,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
SPATIAL_DIM,
ConvDirection::FORWARD>());
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
using Ops = factory_internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto FWD_CONV_SPECIALIZATION =
@@ -858,7 +860,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
SPATIAL_DIM,
ConvDirection::FORWARD>());
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
using Ops = factory_internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto FWD_CONV_SPECIALIZATION =
@@ -980,7 +982,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
SPATIAL_DIM,
ConvDirection::FORWARD>());
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
using Ops = factory_internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
using AlgorithmType = decltype(ALGORITHM);
static constexpr auto BASE_ALGORITHM = ALGORITHM.base_algorithm;

View File

@@ -0,0 +1,47 @@
// Copyright (C) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include <concepts>
#include <type_traits>
#include "ck_tile/builder/types.hpp"
namespace ck_tile::builder {
/**********************************************
* constexpr helper functions for optional parameters
**********************************************/
template <auto Sig>
concept ProvidesElementwiseOperation = requires { Sig.elementwiseOperation; };
template <auto Sig>
concept ProvidesConvolutionDirection = requires { Sig.direction; };
template <auto Sig>
constexpr auto get_elementwise_operation()
{
if constexpr(ProvidesElementwiseOperation<Sig>)
{
return Sig.elementwise_operation;
}
else
{
return ElementwiseOperation::PASS_THROUGH;
}
}
template <auto Sig>
constexpr auto get_conv_direction()
{
if constexpr(ProvidesConvolutionDirection<Sig>)
{
return Sig.direction;
}
else
{
return ConvDirection::FORWARD;
}
}
} // namespace ck_tile::builder