mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 21:51:28 +00:00
[CK_BUILDER] Replace reference conv with old ck implementation (#3604)
* ck-builder: remove SPATIAL_DIM parameter from ConvTensorLayouts This information is already in the SIGNATURE, so it's pointless to pass it separately. This streamlines the interface of those functions a bit. Also touches up the style of those files in general. * ck-builder: implement reference conv using old ck The old ck implementation is more featureful and better tested. * ck-builder: replace test_reference_execution reference with old ck This strips out the ck-tile gpu reference implementation completely. * ck-builder: clean up test_reference_execution - Remove unnecessary messages - Replace EXPECT_TRUE(true) with EXPECT_NO_THROW()
This commit is contained in:
@@ -23,7 +23,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvBwdWeightDlFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
@@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvBwdWeightMultiDWmmaV3Factory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
@@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvBwdWeightMultiDXdlFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
@@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvBwdWeightTwoStageWmmaV3Factory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
@@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvBwdWeightTwoStageXdlFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
@@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvBwdWeightWmmaFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
@@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvBwdWeightWmmaV3Factory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
@@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvBwdWeightXdlFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
@@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvBwdWeightXdlV3Factory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
@@ -24,7 +24,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvFwdDlFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
@@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvFwdLargeTensorFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
@@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvFwdXdlV3Factory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
@@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvFwdWmmaFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
@@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvFwdXdlFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
@@ -29,7 +29,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvTileFactory
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = internal::TileConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Layouts = internal::TileConvTensorLayouts<SIGNATURE>;
|
||||
using Types = internal::TileConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = internal::TileElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
@@ -172,10 +172,10 @@ struct LayoutToCK<TensorLayout::GNDHWK>
|
||||
using type = ck::tensor_layout::convolution::GNDHWK;
|
||||
};
|
||||
|
||||
template <TensorLayout Layout>
|
||||
template <TensorLayout LAYOUT>
|
||||
consteval auto TensorLayoutToCK()
|
||||
{
|
||||
return typename LayoutToCK<Layout>::type{};
|
||||
return typename LayoutToCK<LAYOUT>::type{};
|
||||
}
|
||||
|
||||
struct EmptyAuxiliaryTensorLayout
|
||||
@@ -183,49 +183,52 @@ struct EmptyAuxiliaryTensorLayout
|
||||
using type = ck::Tuple<>;
|
||||
};
|
||||
|
||||
template <auto AuxiliaryTensorConfigsArray, size_t... Indices>
|
||||
template <auto AUXILIARY_TENSOR_CONFIGS_ARRAY, size_t... Indices>
|
||||
consteval auto GetAuxiliaryTensorLayoutTuple(std::index_sequence<Indices...>)
|
||||
{
|
||||
return ck::Tuple<
|
||||
decltype(TensorLayoutToCK<AuxiliaryTensorConfigsArray[Indices].layout>())...>{};
|
||||
decltype(TensorLayoutToCK<AUXILIARY_TENSOR_CONFIGS_ARRAY[Indices].layout>())...>{};
|
||||
}
|
||||
|
||||
template <auto AuxiliaryTensorConfigsValue, size_t SPATIAL_DIM>
|
||||
template <auto AUXILIARY_TENSOR_CONFIGS_VALUE, size_t SPATIAL_DIM>
|
||||
requires(ConvSpatialDim<SPATIAL_DIM>)
|
||||
struct AuxiliaryTensorLayouts
|
||||
{
|
||||
static constexpr auto Size = AuxiliaryTensorConfigsValue.size();
|
||||
using type = decltype(GetAuxiliaryTensorLayoutTuple<AuxiliaryTensorConfigsValue>(
|
||||
static constexpr auto Size = AUXILIARY_TENSOR_CONFIGS_VALUE.size();
|
||||
using type = decltype(GetAuxiliaryTensorLayoutTuple<AUXILIARY_TENSOR_CONFIGS_VALUE>(
|
||||
std::make_index_sequence<Size>{}));
|
||||
};
|
||||
|
||||
// TODO: Currently only the output tensor can have auxiliary tensors (e.g., bias).
|
||||
template <auto Signature, size_t SPATIAL_DIM>
|
||||
requires(HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
|
||||
template <auto SIGNATURE>
|
||||
requires HasElementwiseOpWithAuxiliaryOperands<decltype(SIGNATURE.output)>
|
||||
consteval auto GetAuxiliaryTensorLayouts()
|
||||
{
|
||||
return AuxiliaryTensorLayouts<Signature.output.operation.auxiliary_operand_configs,
|
||||
SPATIAL_DIM>{};
|
||||
return AuxiliaryTensorLayouts<SIGNATURE.output.operation.auxiliary_operand_configs,
|
||||
SIGNATURE.spatial_dim>{};
|
||||
}
|
||||
|
||||
template <auto Signature, size_t SPATIAL_DIM>
|
||||
requires(!HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
|
||||
template <auto SIGNATURE>
|
||||
requires(!HasElementwiseOpWithAuxiliaryOperands<decltype(SIGNATURE.output)>)
|
||||
consteval auto GetAuxiliaryTensorLayouts()
|
||||
{
|
||||
return EmptyAuxiliaryTensorLayout{};
|
||||
}
|
||||
|
||||
template <auto Signature, size_t SPATIAL_DIM>
|
||||
requires(ConvSpatialDim<SPATIAL_DIM> &&
|
||||
ValidConvInputLayoutForSpatialDim<Signature.input.config.layout, SPATIAL_DIM> &&
|
||||
ValidConvWeightLayoutForSpatialDim<Signature.weight.config.layout, SPATIAL_DIM> &&
|
||||
ValidConvOutputLayoutForSpatialDim<Signature.output.config.layout, SPATIAL_DIM>)
|
||||
template <auto SIGNATURE>
|
||||
requires ConvSpatialDim<SIGNATURE.spatial_dim> &&
|
||||
ValidConvInputLayoutForSpatialDim<SIGNATURE.input.config.layout,
|
||||
SIGNATURE.spatial_dim> &&
|
||||
ValidConvWeightLayoutForSpatialDim<SIGNATURE.weight.config.layout,
|
||||
SIGNATURE.spatial_dim> &&
|
||||
ValidConvOutputLayoutForSpatialDim<SIGNATURE.output.config.layout,
|
||||
SIGNATURE.spatial_dim>
|
||||
struct ConvTensorLayouts
|
||||
{
|
||||
using InLayout = decltype(TensorLayoutToCK<Signature.input.config.layout>());
|
||||
using WeiLayout = decltype(TensorLayoutToCK<Signature.weight.config.layout>());
|
||||
using OutLayout = decltype(TensorLayoutToCK<Signature.output.config.layout>());
|
||||
using DsLayout = decltype(GetAuxiliaryTensorLayouts<Signature, SPATIAL_DIM>())::type;
|
||||
using InLayout = decltype(TensorLayoutToCK<SIGNATURE.input.config.layout>());
|
||||
using WeiLayout = decltype(TensorLayoutToCK<SIGNATURE.weight.config.layout>());
|
||||
using OutLayout = decltype(TensorLayoutToCK<SIGNATURE.output.config.layout>());
|
||||
using DsLayout = decltype(GetAuxiliaryTensorLayouts<SIGNATURE>())::type;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
|
||||
@@ -9,10 +9,10 @@
|
||||
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
using ALayout = ck_tile::tensor_layout::convolution::NWGC;
|
||||
template <TensorLayout Layout>
|
||||
template <TensorLayout LAYOUT>
|
||||
struct LayoutToCKTile
|
||||
{
|
||||
static_assert(sizeof(UnsupportedEnumValue<Layout>) == 0,
|
||||
static_assert(sizeof(UnsupportedEnumValue<LAYOUT>) == 0,
|
||||
"Unsupported layout conversion to CK.");
|
||||
};
|
||||
|
||||
@@ -152,49 +152,52 @@ struct EmptyAuxiliaryTileTensorLayout
|
||||
using type = ck_tile::tuple<>;
|
||||
};
|
||||
|
||||
template <auto AuxiliaryTileTensorConfigsArray, size_t... Indices>
|
||||
template <auto AUXILIARY_TILE_TENSOR_CONFIGS_ARRAY, size_t... Indices>
|
||||
consteval auto GetAuxiliaryTileTensorLayoutTuple(std::index_sequence<Indices...>)
|
||||
{
|
||||
return ck_tile::tuple<
|
||||
decltype(TensorLayoutToCKTile<AuxiliaryTileTensorConfigsArray[Indices].layout>())...>{};
|
||||
decltype(TensorLayoutToCKTile<AUXILIARY_TILE_TENSOR_CONFIGS_ARRAY[Indices].layout>())...>{};
|
||||
}
|
||||
|
||||
template <auto AuxiliaryTileTensorConfigsValue, size_t SPATIAL_DIM>
|
||||
requires(ConvSpatialDim<SPATIAL_DIM>)
|
||||
template <auto AUXILIARY_TILE_TENSOR_CONFIGS_VALUE, size_t SPATIAL_DIM>
|
||||
requires ConvSpatialDim<SPATIAL_DIM>
|
||||
struct AuxiliaryTileTensorLayouts
|
||||
{
|
||||
static constexpr auto Size = AuxiliaryTileTensorConfigsValue.size();
|
||||
using type = decltype(GetAuxiliaryTileTensorLayoutTuple<AuxiliaryTileTensorConfigsValue>(
|
||||
static constexpr auto Size = AUXILIARY_TILE_TENSOR_CONFIGS_VALUE.size();
|
||||
using type = decltype(GetAuxiliaryTileTensorLayoutTuple<AUXILIARY_TILE_TENSOR_CONFIGS_VALUE>(
|
||||
std::make_index_sequence<Size>{}));
|
||||
};
|
||||
|
||||
// TODO: Currently only the output tensor can have auxiliary tensors (e.g., bias).
|
||||
template <auto Signature, size_t SPATIAL_DIM>
|
||||
requires(HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
|
||||
template <auto SIGNATURE>
|
||||
requires HasElementwiseOpWithAuxiliaryOperands<decltype(SIGNATURE.output)>
|
||||
consteval auto GetAuxiliaryTileTensorLayouts()
|
||||
{
|
||||
return AuxiliaryTileTensorLayouts<Signature.output.operation.auxiliary_operand_configs,
|
||||
SPATIAL_DIM>{};
|
||||
return AuxiliaryTileTensorLayouts<SIGNATURE.output.operation.auxiliary_operand_configs,
|
||||
SIGNATURE.spatial_dim>{};
|
||||
}
|
||||
|
||||
template <auto Signature, size_t SPATIAL_DIM>
|
||||
requires(!HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
|
||||
template <auto SIGNATURE>
|
||||
requires(!HasElementwiseOpWithAuxiliaryOperands<decltype(SIGNATURE.output)>)
|
||||
consteval auto GetAuxiliaryTileTensorLayouts()
|
||||
{
|
||||
return EmptyAuxiliaryTileTensorLayout{};
|
||||
}
|
||||
|
||||
template <auto Signature, size_t SPATIAL_DIM>
|
||||
requires(ConvSpatialDim<SPATIAL_DIM> &&
|
||||
ValidConvInputLayoutForSpatialDim<Signature.input.config.layout, SPATIAL_DIM> &&
|
||||
ValidConvWeightLayoutForSpatialDim<Signature.weight.config.layout, SPATIAL_DIM> &&
|
||||
ValidConvOutputLayoutForSpatialDim<Signature.output.config.layout, SPATIAL_DIM>)
|
||||
template <auto SIGNATURE>
|
||||
requires ConvSpatialDim<SIGNATURE.spatial_dim> &&
|
||||
ValidConvInputLayoutForSpatialDim<SIGNATURE.input.config.layout,
|
||||
SIGNATURE.spatial_dim> &&
|
||||
ValidConvWeightLayoutForSpatialDim<SIGNATURE.weight.config.layout,
|
||||
SIGNATURE.spatial_dim> &&
|
||||
ValidConvOutputLayoutForSpatialDim<SIGNATURE.output.config.layout,
|
||||
SIGNATURE.spatial_dim>
|
||||
struct TileConvTensorLayouts
|
||||
{
|
||||
using ALayout = decltype(TensorLayoutToCKTile<Signature.input.config.layout>());
|
||||
using BLayout = decltype(TensorLayoutToCKTile<Signature.weight.config.layout>());
|
||||
using ELayout = decltype(TensorLayoutToCKTile<Signature.output.config.layout>());
|
||||
using DsLayout = decltype(GetAuxiliaryTileTensorLayouts<Signature, SPATIAL_DIM>())::type;
|
||||
using ALayout = decltype(TensorLayoutToCKTile<SIGNATURE.input.config.layout>());
|
||||
using BLayout = decltype(TensorLayoutToCKTile<SIGNATURE.weight.config.layout>());
|
||||
using ELayout = decltype(TensorLayoutToCKTile<SIGNATURE.output.config.layout>());
|
||||
using DsLayout = decltype(GetAuxiliaryTileTensorLayouts<SIGNATURE>())::type;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
|
||||
@@ -1,118 +0,0 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
#include <vector>
|
||||
|
||||
namespace ck_tile::builder::factory::internal {
|
||||
|
||||
// Validation helper: Ensure reference implementation only receives PassThrough elementwise ops
|
||||
template <auto SIGNATURE>
|
||||
consteval void ValidateReferenceSignature()
|
||||
{
|
||||
using namespace ck_tile::builder;
|
||||
|
||||
// Check input elementwise operation
|
||||
static_assert(
|
||||
!HasTensorOp<decltype(SIGNATURE.input)> ||
|
||||
SIGNATURE.input.operation.elementwise_operation == ElementwiseOperation::PASS_THROUGH,
|
||||
"Reference implementation does not support elementwise operations on input tensor. "
|
||||
"Input operation must be PassThrough (or not specified).");
|
||||
|
||||
// Check weight elementwise operation
|
||||
static_assert(
|
||||
!HasTensorOp<decltype(SIGNATURE.weight)> ||
|
||||
SIGNATURE.weight.operation.elementwise_operation == ElementwiseOperation::PASS_THROUGH,
|
||||
"Reference implementation does not support elementwise operations on weight tensor. "
|
||||
"Weight operation must be PassThrough (or not specified).");
|
||||
|
||||
// Check output elementwise operation
|
||||
static_assert(
|
||||
!HasTensorOp<decltype(SIGNATURE.output)> ||
|
||||
SIGNATURE.output.operation.elementwise_operation == ElementwiseOperation::PASS_THROUGH,
|
||||
"Reference implementation does not support elementwise operations on output tensor. "
|
||||
"Output operation must be PassThrough (or not specified).");
|
||||
}
|
||||
|
||||
// Common argument structure for reference convolution implementations
|
||||
// Template parameters allow different const qualifiers for each direction
|
||||
template <typename InPtrType, typename WeiPtrType, typename OutPtrType>
|
||||
struct ReferenceConvArgument
|
||||
{
|
||||
InPtrType input_;
|
||||
WeiPtrType weight_;
|
||||
OutPtrType output_;
|
||||
int G_, N_, K_, C_;
|
||||
std::vector<ck_tile::long_index_t> input_spatial_;
|
||||
std::vector<ck_tile::long_index_t> filter_spatial_;
|
||||
std::vector<ck_tile::long_index_t> output_spatial_;
|
||||
std::vector<ck_tile::long_index_t> strides_;
|
||||
std::vector<ck_tile::long_index_t> dilations_;
|
||||
std::vector<ck_tile::long_index_t> left_pads_;
|
||||
|
||||
ReferenceConvArgument(InPtrType input,
|
||||
WeiPtrType weight,
|
||||
OutPtrType output,
|
||||
int G,
|
||||
int N,
|
||||
int K,
|
||||
int C,
|
||||
const std::vector<ck_tile::long_index_t>& input_spatial,
|
||||
const std::vector<ck_tile::long_index_t>& filter_spatial,
|
||||
const std::vector<ck_tile::long_index_t>& output_spatial,
|
||||
const std::vector<ck_tile::long_index_t>& strides,
|
||||
const std::vector<ck_tile::long_index_t>& dilations,
|
||||
const std::vector<ck_tile::long_index_t>& left_pads)
|
||||
: input_(input),
|
||||
weight_(weight),
|
||||
output_(output),
|
||||
G_(G),
|
||||
N_(N),
|
||||
K_(K),
|
||||
C_(C),
|
||||
input_spatial_(input_spatial),
|
||||
filter_spatial_(filter_spatial),
|
||||
output_spatial_(output_spatial),
|
||||
strides_(strides),
|
||||
dilations_(dilations),
|
||||
left_pads_(left_pads)
|
||||
{
|
||||
}
|
||||
};
|
||||
|
||||
// Common invoker structure for reference convolution implementations
|
||||
// Takes a callable (lambda or function pointer) to execute the actual convolution
|
||||
template <typename ArgumentType, typename ConvFunc>
|
||||
struct ReferenceConvInvoker
|
||||
{
|
||||
ConvFunc conv_func_;
|
||||
|
||||
explicit ReferenceConvInvoker(ConvFunc func) : conv_func_(func) {}
|
||||
|
||||
float Run(const ArgumentType* arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
{
|
||||
(void)stream_config; // Unused for reference implementation
|
||||
|
||||
conv_func_(arg->input_,
|
||||
arg->weight_,
|
||||
arg->output_,
|
||||
arg->G_,
|
||||
arg->N_,
|
||||
arg->K_,
|
||||
arg->C_,
|
||||
arg->input_spatial_,
|
||||
arg->filter_spatial_,
|
||||
arg->output_spatial_,
|
||||
arg->strides_,
|
||||
arg->dilations_,
|
||||
arg->left_pads_);
|
||||
|
||||
return 0.0f; // Reference implementation doesn't track timing
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder::factory::internal
|
||||
@@ -3,15 +3,15 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ref/naive_grouped_conv_fwd_gpu.hpp"
|
||||
#include "ck_tile/ref/naive_grouped_conv_bwd_data_gpu.hpp"
|
||||
#include "ck_tile/ref/naive_grouped_conv_bwd_weight_gpu.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
|
||||
#include "ck_tile/builder/factory/reference_common.hpp"
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
|
||||
#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp"
|
||||
#include "ck/library/reference_tensor_operation/gpu/naive_conv_bwd_weight_gpu.hpp"
|
||||
#include "ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp"
|
||||
#include "ck/library/utility/convolution_parameter.hpp"
|
||||
#include <memory>
|
||||
|
||||
namespace ck_tile::builder::factory {
|
||||
@@ -22,16 +22,23 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
StringLiteral VERSION>
|
||||
struct ReferenceFactory
|
||||
{
|
||||
// Validate that only PassThrough elementwise operations are specified
|
||||
static constexpr auto kValidation = (internal::ValidateReferenceSignature<SIGNATURE>(), 0);
|
||||
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
|
||||
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
|
||||
using InDataType = typename Types::InDataType;
|
||||
using WeiDataType = typename Types::WeiDataType;
|
||||
using OutDataType = typename Types::OutDataType;
|
||||
|
||||
using Layouts = factory::internal::ConvTensorLayouts<SIGNATURE>;
|
||||
using InLayout = typename Layouts::InLayout;
|
||||
using WeiLayout = typename Layouts::WeiLayout;
|
||||
using OutLayout = typename Layouts::OutLayout;
|
||||
|
||||
using Ops = factory::internal::ConvElementwiseOps<SIGNATURE>;
|
||||
using InElementwiseOp = typename Ops::InElementwiseOp;
|
||||
using WeiElementwiseOp = typename Ops::WeiElementwiseOp;
|
||||
using OutElementwiseOp = typename Ops::OutElementwiseOp;
|
||||
|
||||
struct Instance
|
||||
{
|
||||
// Store template parameters for InstanceTraits reflection
|
||||
@@ -39,91 +46,57 @@ struct ReferenceFactory
|
||||
static constexpr auto kAlgorithm = ALGORITHM;
|
||||
static constexpr auto kVersion = VERSION;
|
||||
|
||||
// Argument and Invoker types depend on direction
|
||||
// Forward: const input, const weight, mutable output
|
||||
// Backward Data: mutable input, const weight, const output_grad
|
||||
// Backward Weight: const input, mutable weight_grad, const output_grad
|
||||
|
||||
// Use appropriate Argument type based on direction
|
||||
using Argument = std::conditional_t<
|
||||
ConvDirectionIsForward<SIGNATURE>,
|
||||
internal::ReferenceConvArgument<const InDataType*, const WeiDataType*, OutDataType*>,
|
||||
std::conditional_t<
|
||||
ConvDirectionIsBackwardData<SIGNATURE>,
|
||||
internal::
|
||||
ReferenceConvArgument<InDataType*, const WeiDataType*, const OutDataType*>,
|
||||
internal::
|
||||
ReferenceConvArgument<const InDataType*, WeiDataType*, const OutDataType*>>>;
|
||||
|
||||
// Invoker calls the appropriate reference implementation based on direction
|
||||
struct Invoker
|
||||
/// @brief Invoke reference convolution
|
||||
///
|
||||
/// This is the primary overload to invoke reference convolution. As the underlying
|
||||
/// function requires it, this function accepts ConvParam directly.
|
||||
template <typename InPtrType, typename WeiPtrType, typename OutPtrType>
|
||||
static void Run(InPtrType* input,
|
||||
WeiPtrType* weight,
|
||||
OutPtrType* output,
|
||||
const ck::utils::conv::ConvParam& param,
|
||||
InElementwiseOp in_op = InElementwiseOp{},
|
||||
WeiElementwiseOp wei_op = WeiElementwiseOp{},
|
||||
OutElementwiseOp out_op = OutElementwiseOp{})
|
||||
{
|
||||
float Run(const Argument* arg, const StreamConfig& stream_config = StreamConfig{})
|
||||
if constexpr(ConvDirectionIsForward<SIGNATURE>)
|
||||
{
|
||||
(void)stream_config; // Unused for reference implementation
|
||||
|
||||
if constexpr(ConvDirectionIsForward<SIGNATURE>)
|
||||
{
|
||||
ck_tile::
|
||||
naive_grouped_conv_fwd<SPATIAL_DIM, InDataType, WeiDataType, OutDataType>(
|
||||
arg->input_,
|
||||
arg->weight_,
|
||||
arg->output_,
|
||||
arg->G_,
|
||||
arg->N_,
|
||||
arg->K_,
|
||||
arg->C_,
|
||||
arg->input_spatial_,
|
||||
arg->filter_spatial_,
|
||||
arg->output_spatial_,
|
||||
arg->strides_,
|
||||
arg->dilations_,
|
||||
arg->left_pads_);
|
||||
}
|
||||
else if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
|
||||
{
|
||||
ck_tile::naive_grouped_conv_bwd_data<SPATIAL_DIM,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType>(arg->input_,
|
||||
arg->weight_,
|
||||
arg->output_,
|
||||
arg->G_,
|
||||
arg->N_,
|
||||
arg->K_,
|
||||
arg->C_,
|
||||
arg->input_spatial_,
|
||||
arg->filter_spatial_,
|
||||
arg->output_spatial_,
|
||||
arg->strides_,
|
||||
arg->dilations_,
|
||||
arg->left_pads_);
|
||||
}
|
||||
else if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
|
||||
{
|
||||
ck_tile::naive_grouped_conv_bwd_weight<SPATIAL_DIM,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType>(arg->input_,
|
||||
arg->weight_,
|
||||
arg->output_,
|
||||
arg->G_,
|
||||
arg->N_,
|
||||
arg->K_,
|
||||
arg->C_,
|
||||
arg->input_spatial_,
|
||||
arg->filter_spatial_,
|
||||
arg->output_spatial_,
|
||||
arg->strides_,
|
||||
arg->dilations_,
|
||||
arg->left_pads_);
|
||||
}
|
||||
|
||||
return 0.0f; // Reference implementation doesn't track timing
|
||||
ck::ref::naive_conv_fwd<InLayout, WeiLayout, OutLayout>(
|
||||
static_cast<const InDataType*>(input),
|
||||
static_cast<const WeiDataType*>(weight),
|
||||
static_cast<OutDataType*>(output),
|
||||
param,
|
||||
in_op,
|
||||
wei_op,
|
||||
out_op);
|
||||
}
|
||||
};
|
||||
else if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
|
||||
{
|
||||
ck::ref::naive_conv_bwd_data<InLayout, WeiLayout, OutLayout>(
|
||||
static_cast<InDataType*>(input),
|
||||
static_cast<const WeiDataType*>(weight),
|
||||
static_cast<const OutDataType*>(output),
|
||||
param,
|
||||
in_op,
|
||||
wei_op,
|
||||
out_op);
|
||||
}
|
||||
else if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
|
||||
{
|
||||
ck::ref::naive_conv_bwd_weight<InLayout, WeiLayout, OutLayout>(
|
||||
static_cast<const InDataType*>(input),
|
||||
static_cast<WeiDataType*>(weight),
|
||||
static_cast<const OutDataType*>(output),
|
||||
param,
|
||||
in_op,
|
||||
wei_op,
|
||||
out_op);
|
||||
}
|
||||
}
|
||||
|
||||
// Direct Run method (simpler interface, direction-agnostic)
|
||||
/// @brief Invoke reference convolution
|
||||
///
|
||||
/// Convenience overload to avoid having to construct ConvParam manually.
|
||||
template <typename InPtrType, typename WeiPtrType, typename OutPtrType>
|
||||
static void Run(InPtrType* input,
|
||||
WeiPtrType* weight,
|
||||
@@ -132,68 +105,27 @@ struct ReferenceFactory
|
||||
int N,
|
||||
int K,
|
||||
int C,
|
||||
const std::vector<ck_tile::long_index_t>& input_spatial,
|
||||
const std::vector<ck_tile::long_index_t>& filter_spatial,
|
||||
const std::vector<ck_tile::long_index_t>& output_spatial,
|
||||
const std::vector<ck_tile::long_index_t>& strides,
|
||||
const std::vector<ck_tile::long_index_t>& dilations,
|
||||
const std::vector<ck_tile::long_index_t>& left_pads)
|
||||
const std::vector<ck::long_index_t>& input_spatial,
|
||||
const std::vector<ck::long_index_t>& filter_spatial,
|
||||
const std::vector<ck::long_index_t>& strides,
|
||||
const std::vector<ck::long_index_t>& dilations,
|
||||
const std::vector<ck::long_index_t>& left_pads,
|
||||
const std::vector<ck::long_index_t>& right_pads)
|
||||
{
|
||||
if constexpr(ConvDirectionIsForward<SIGNATURE>)
|
||||
{
|
||||
ck_tile::naive_grouped_conv_fwd<SPATIAL_DIM, InDataType, WeiDataType, OutDataType>(
|
||||
static_cast<const InDataType*>(input),
|
||||
static_cast<const WeiDataType*>(weight),
|
||||
static_cast<OutDataType*>(output),
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial,
|
||||
filter_spatial,
|
||||
output_spatial,
|
||||
strides,
|
||||
dilations,
|
||||
left_pads);
|
||||
}
|
||||
else if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
|
||||
{
|
||||
ck_tile::
|
||||
naive_grouped_conv_bwd_data<SPATIAL_DIM, InDataType, WeiDataType, OutDataType>(
|
||||
static_cast<InDataType*>(input),
|
||||
static_cast<const WeiDataType*>(weight),
|
||||
static_cast<const OutDataType*>(output),
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial,
|
||||
filter_spatial,
|
||||
output_spatial,
|
||||
strides,
|
||||
dilations,
|
||||
left_pads);
|
||||
}
|
||||
else if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
|
||||
{
|
||||
ck_tile::naive_grouped_conv_bwd_weight<SPATIAL_DIM,
|
||||
InDataType,
|
||||
WeiDataType,
|
||||
OutDataType>(
|
||||
static_cast<const InDataType*>(input),
|
||||
static_cast<WeiDataType*>(weight),
|
||||
static_cast<const OutDataType*>(output),
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial,
|
||||
filter_spatial,
|
||||
output_spatial,
|
||||
strides,
|
||||
dilations,
|
||||
left_pads);
|
||||
}
|
||||
Run(input,
|
||||
weight,
|
||||
output,
|
||||
ck::utils::conv::ConvParam(SPATIAL_DIM,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
filter_spatial,
|
||||
input_spatial,
|
||||
strides,
|
||||
dilations,
|
||||
left_pads,
|
||||
right_pads));
|
||||
}
|
||||
|
||||
std::string GetTypeString() const
|
||||
@@ -209,41 +141,6 @@ struct ReferenceFactory
|
||||
return std::string("GPU_Reference_") + dir_str + "_" + std::to_string(SPATIAL_DIM) +
|
||||
"D";
|
||||
}
|
||||
|
||||
// Old CK interface: Create argument pointer
|
||||
template <typename InPtrType, typename WeiPtrType, typename OutPtrType>
|
||||
std::unique_ptr<Argument>
|
||||
MakeArgumentPointer(InPtrType input,
|
||||
WeiPtrType weight,
|
||||
OutPtrType output,
|
||||
int G,
|
||||
int N,
|
||||
int K,
|
||||
int C,
|
||||
const std::vector<ck_tile::long_index_t>& input_spatial,
|
||||
const std::vector<ck_tile::long_index_t>& filter_spatial,
|
||||
const std::vector<ck_tile::long_index_t>& output_spatial,
|
||||
const std::vector<ck_tile::long_index_t>& strides,
|
||||
const std::vector<ck_tile::long_index_t>& dilations,
|
||||
const std::vector<ck_tile::long_index_t>& left_pads) const
|
||||
{
|
||||
return std::make_unique<Argument>(input,
|
||||
weight,
|
||||
output,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial,
|
||||
filter_spatial,
|
||||
output_spatial,
|
||||
strides,
|
||||
dilations,
|
||||
left_pads);
|
||||
}
|
||||
|
||||
// Old CK interface: Create invoker pointer
|
||||
std::unique_ptr<Invoker> MakeInvokerPointer() const { return std::make_unique<Invoker>(); }
|
||||
};
|
||||
};
|
||||
|
||||
|
||||
@@ -76,7 +76,7 @@ struct Args<SIGNATURE>
|
||||
using Ops = factory::internal::ConvElementwiseOps<SIGNATURE>;
|
||||
|
||||
// TODO: We shouldn't need to call into an internal namespace here.
|
||||
using Layouts = factory::internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
|
||||
using Layouts = factory::internal::ConvTensorLayouts<SIGNATURE>;
|
||||
|
||||
ConvTensorLengths<SPATIAL_DIM> lengths;
|
||||
|
||||
|
||||
@@ -32,27 +32,8 @@ concept RefConvInstance = requires(Conv& conv,
|
||||
const void* input,
|
||||
const void* weight,
|
||||
void* output,
|
||||
int G,
|
||||
int N,
|
||||
int K,
|
||||
int C,
|
||||
std::vector<long_index_t> dims) {
|
||||
{
|
||||
conv.Run(input,
|
||||
weight,
|
||||
output,
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
dims, // input_spatial
|
||||
dims, // filter_spatial
|
||||
dims, // output_spatial
|
||||
dims, // strides
|
||||
dims, // dilations
|
||||
dims // left_pads
|
||||
)
|
||||
};
|
||||
ck::utils::conv::ConvParam param) {
|
||||
{ conv.Run(input, weight, output, param) };
|
||||
};
|
||||
|
||||
/// @brief `run()` specialization for forward convolution and the reference
|
||||
@@ -84,16 +65,6 @@ std::tuple<bool, float> run(RefConvInstance<SIGNATURE> auto& conv,
|
||||
// Just throw for now, but regard these as TODO items that should be resolved
|
||||
// eventually.
|
||||
|
||||
// Right pads are not supported right now for some reason.
|
||||
for(auto right_pad : param.input_right_pads_)
|
||||
{
|
||||
if(right_pad != 0)
|
||||
{
|
||||
std::cout << "TODO: Support right pad in reference conv" << std::endl;
|
||||
return std::make_tuple(false, 0.0f);
|
||||
}
|
||||
}
|
||||
|
||||
if(!args.make_input_descriptor().is_packed())
|
||||
{
|
||||
std::cout << "TODO: Support non-packed input tensor in reference conv" << std::endl;
|
||||
@@ -110,19 +81,7 @@ std::tuple<bool, float> run(RefConvInstance<SIGNATURE> auto& conv,
|
||||
return std::make_tuple(false, 0.0f);
|
||||
}
|
||||
|
||||
conv.Run(inputs.input,
|
||||
inputs.weight,
|
||||
outputs.output,
|
||||
param.G_,
|
||||
param.N_,
|
||||
param.K_,
|
||||
param.C_,
|
||||
param.input_spatial_lengths_,
|
||||
param.filter_spatial_lengths_,
|
||||
param.output_spatial_lengths_,
|
||||
param.conv_filter_strides_,
|
||||
param.conv_filter_dilations_,
|
||||
param.input_left_pads_);
|
||||
conv.Run(inputs.input, inputs.weight, outputs.output, param);
|
||||
return std::make_tuple(true, 0.0f);
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user