[CK_BUILDER] Replace reference conv with old ck implementation (#3604)

* ck-builder: remove SPATIAL_DIM parameter from ConvTensorLayouts

This information is already in the SIGNATURE, so its pointless to pass it
separately. This streamlines the interface of those functions a bit. Also
touches up the style of those files in general.

* ck-builder: implement reference conv using old ck

The old ck implementation is more featureful and better tested.

* ck-builder: replace test_reference_execution reference with old ck

This strips out the ck-tile gpu reference implementation completely.

* ck-builder: clean up test_reference_execution

- Remove unneccesary messages
- Replace EXPECT_TRUE(true) with EXPECT_NO_THROW()
This commit is contained in:
Robin Voetter
2026-01-21 19:18:47 +01:00
committed by GitHub
parent 0fbb3bb8c4
commit 1040d9b1f5
24 changed files with 291 additions and 1067 deletions

View File

@@ -23,7 +23,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvBwdWeightDlFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);

View File

@@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvBwdWeightMultiDWmmaV3Factory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);

View File

@@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvBwdWeightMultiDXdlFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);

View File

@@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvBwdWeightTwoStageWmmaV3Factory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);

View File

@@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvBwdWeightTwoStageXdlFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);

View File

@@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvBwdWeightWmmaFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);

View File

@@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvBwdWeightWmmaV3Factory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);

View File

@@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvBwdWeightXdlFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);

View File

@@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvBwdWeightXdlV3Factory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);

View File

@@ -24,7 +24,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFwdDlFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);

View File

@@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFwdLargeTensorFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);

View File

@@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFwdXdlV3Factory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);

View File

@@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFwdWmmaFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);

View File

@@ -26,7 +26,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvFwdXdlFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = internal::ConvTensorLayouts<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Ops = internal::ConvElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);

View File

@@ -29,7 +29,7 @@ template <ConvSignatureDescriptor auto SIGNATURE,
struct ConvTileFactory
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts = internal::TileConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = internal::TileConvTensorLayouts<SIGNATURE>;
using Types = internal::TileConvTensorTypes<SIGNATURE.data_type>;
using Ops = internal::TileElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);

View File

@@ -172,10 +172,10 @@ struct LayoutToCK<TensorLayout::GNDHWK>
using type = ck::tensor_layout::convolution::GNDHWK;
};
template <TensorLayout Layout>
template <TensorLayout LAYOUT>
consteval auto TensorLayoutToCK()
{
return typename LayoutToCK<Layout>::type{};
return typename LayoutToCK<LAYOUT>::type{};
}
struct EmptyAuxiliaryTensorLayout
@@ -183,49 +183,52 @@ struct EmptyAuxiliaryTensorLayout
using type = ck::Tuple<>;
};
template <auto AuxiliaryTensorConfigsArray, size_t... Indices>
template <auto AUXILIARY_TENSOR_CONFIGS_ARRAY, size_t... Indices>
consteval auto GetAuxiliaryTensorLayoutTuple(std::index_sequence<Indices...>)
{
return ck::Tuple<
decltype(TensorLayoutToCK<AuxiliaryTensorConfigsArray[Indices].layout>())...>{};
decltype(TensorLayoutToCK<AUXILIARY_TENSOR_CONFIGS_ARRAY[Indices].layout>())...>{};
}
template <auto AuxiliaryTensorConfigsValue, size_t SPATIAL_DIM>
template <auto AUXILIARY_TENSOR_CONFIGS_VALUE, size_t SPATIAL_DIM>
requires(ConvSpatialDim<SPATIAL_DIM>)
struct AuxiliaryTensorLayouts
{
static constexpr auto Size = AuxiliaryTensorConfigsValue.size();
using type = decltype(GetAuxiliaryTensorLayoutTuple<AuxiliaryTensorConfigsValue>(
static constexpr auto Size = AUXILIARY_TENSOR_CONFIGS_VALUE.size();
using type = decltype(GetAuxiliaryTensorLayoutTuple<AUXILIARY_TENSOR_CONFIGS_VALUE>(
std::make_index_sequence<Size>{}));
};
// TODO: Currently only the ouput tensor can have auxiliary tensors (e.g., bias).
template <auto Signature, size_t SPATIAL_DIM>
requires(HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
template <auto SIGNATURE>
requires HasElementwiseOpWithAuxiliaryOperands<decltype(SIGNATURE.output)>
consteval auto GetAuxiliaryTensorLayouts()
{
return AuxiliaryTensorLayouts<Signature.output.operation.auxiliary_operand_configs,
SPATIAL_DIM>{};
return AuxiliaryTensorLayouts<SIGNATURE.output.operation.auxiliary_operand_configs,
SIGNATURE.spatial_dim>{};
}
template <auto Signature, size_t SPATIAL_DIM>
requires(!HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
template <auto SIGNATURE>
requires(!HasElementwiseOpWithAuxiliaryOperands<decltype(SIGNATURE.output)>)
consteval auto GetAuxiliaryTensorLayouts()
{
return EmptyAuxiliaryTensorLayout{};
}
template <auto Signature, size_t SPATIAL_DIM>
requires(ConvSpatialDim<SPATIAL_DIM> &&
ValidConvInputLayoutForSpatialDim<Signature.input.config.layout, SPATIAL_DIM> &&
ValidConvWeightLayoutForSpatialDim<Signature.weight.config.layout, SPATIAL_DIM> &&
ValidConvOutputLayoutForSpatialDim<Signature.output.config.layout, SPATIAL_DIM>)
template <auto SIGNATURE>
requires ConvSpatialDim<SIGNATURE.spatial_dim> &&
ValidConvInputLayoutForSpatialDim<SIGNATURE.input.config.layout,
SIGNATURE.spatial_dim> &&
ValidConvWeightLayoutForSpatialDim<SIGNATURE.weight.config.layout,
SIGNATURE.spatial_dim> &&
ValidConvOutputLayoutForSpatialDim<SIGNATURE.output.config.layout,
SIGNATURE.spatial_dim>
struct ConvTensorLayouts
{
using InLayout = decltype(TensorLayoutToCK<Signature.input.config.layout>());
using WeiLayout = decltype(TensorLayoutToCK<Signature.weight.config.layout>());
using OutLayout = decltype(TensorLayoutToCK<Signature.output.config.layout>());
using DsLayout = decltype(GetAuxiliaryTensorLayouts<Signature, SPATIAL_DIM>())::type;
using InLayout = decltype(TensorLayoutToCK<SIGNATURE.input.config.layout>());
using WeiLayout = decltype(TensorLayoutToCK<SIGNATURE.weight.config.layout>());
using OutLayout = decltype(TensorLayoutToCK<SIGNATURE.output.config.layout>());
using DsLayout = decltype(GetAuxiliaryTensorLayouts<SIGNATURE>())::type;
};
} // namespace ck_tile::builder::factory::internal

View File

@@ -9,10 +9,10 @@
namespace ck_tile::builder::factory::internal {
using ALayout = ck_tile::tensor_layout::convolution::NWGC;
template <TensorLayout Layout>
template <TensorLayout LAYOUT>
struct LayoutToCKTile
{
static_assert(sizeof(UnsupportedEnumValue<Layout>) == 0,
static_assert(sizeof(UnsupportedEnumValue<LAYOUT>) == 0,
"Unsupported layout conversion to CK.");
};
@@ -152,49 +152,52 @@ struct EmptyAuxiliaryTileTensorLayout
using type = ck_tile::tuple<>;
};
template <auto AuxiliaryTileTensorConfigsArray, size_t... Indices>
template <auto AUXILIARY_TILE_TENSOR_CONFIGS_ARRAY, size_t... Indices>
consteval auto GetAuxiliaryTileTensorLayoutTuple(std::index_sequence<Indices...>)
{
return ck_tile::tuple<
decltype(TensorLayoutToCKTile<AuxiliaryTileTensorConfigsArray[Indices].layout>())...>{};
decltype(TensorLayoutToCKTile<AUXILIARY_TILE_TENSOR_CONFIGS_ARRAY[Indices].layout>())...>{};
}
template <auto AuxiliaryTileTensorConfigsValue, size_t SPATIAL_DIM>
requires(ConvSpatialDim<SPATIAL_DIM>)
template <auto AUXILIARY_TILE_TENSOR_CONFIGS_VALUE, size_t SPATIAL_DIM>
requires ConvSpatialDim<SPATIAL_DIM>
struct AuxiliaryTileTensorLayouts
{
static constexpr auto Size = AuxiliaryTileTensorConfigsValue.size();
using type = decltype(GetAuxiliaryTileTensorLayoutTuple<AuxiliaryTileTensorConfigsValue>(
static constexpr auto Size = AUXILIARY_TILE_TENSOR_CONFIGS_VALUE.size();
using type = decltype(GetAuxiliaryTileTensorLayoutTuple<AUXILIARY_TILE_TENSOR_CONFIGS_VALUE>(
std::make_index_sequence<Size>{}));
};
// TODO: Currently only the ouput tensor can have auxiliary tensors (e.g., bias).
template <auto Signature, size_t SPATIAL_DIM>
requires(HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
template <auto SIGNATURE>
requires HasElementwiseOpWithAuxiliaryOperands<decltype(SIGNATURE.output)>
consteval auto GetAuxiliaryTileTensorLayouts()
{
return AuxiliaryTileTensorLayouts<Signature.output.operation.auxiliary_operand_configs,
SPATIAL_DIM>{};
return AuxiliaryTileTensorLayouts<SIGNATURE.output.operation.auxiliary_operand_configs,
SIGNATURE.spatial_dim>{};
}
template <auto Signature, size_t SPATIAL_DIM>
requires(!HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
template <auto SIGNATURE>
requires(!HasElementwiseOpWithAuxiliaryOperands<decltype(SIGNATURE.output)>)
consteval auto GetAuxiliaryTileTensorLayouts()
{
return EmptyAuxiliaryTileTensorLayout{};
}
template <auto Signature, size_t SPATIAL_DIM>
requires(ConvSpatialDim<SPATIAL_DIM> &&
ValidConvInputLayoutForSpatialDim<Signature.input.config.layout, SPATIAL_DIM> &&
ValidConvWeightLayoutForSpatialDim<Signature.weight.config.layout, SPATIAL_DIM> &&
ValidConvOutputLayoutForSpatialDim<Signature.output.config.layout, SPATIAL_DIM>)
template <auto SIGNATURE>
requires ConvSpatialDim<SIGNATURE.spatial_dim> &&
ValidConvInputLayoutForSpatialDim<SIGNATURE.input.config.layout,
SIGNATURE.spatial_dim> &&
ValidConvWeightLayoutForSpatialDim<SIGNATURE.weight.config.layout,
SIGNATURE.spatial_dim> &&
ValidConvOutputLayoutForSpatialDim<SIGNATURE.output.config.layout,
SIGNATURE.spatial_dim>
struct TileConvTensorLayouts
{
using ALayout = decltype(TensorLayoutToCKTile<Signature.input.config.layout>());
using BLayout = decltype(TensorLayoutToCKTile<Signature.weight.config.layout>());
using ELayout = decltype(TensorLayoutToCKTile<Signature.output.config.layout>());
using DsLayout = decltype(GetAuxiliaryTileTensorLayouts<Signature, SPATIAL_DIM>())::type;
using ALayout = decltype(TensorLayoutToCKTile<SIGNATURE.input.config.layout>());
using BLayout = decltype(TensorLayoutToCKTile<SIGNATURE.weight.config.layout>());
using ELayout = decltype(TensorLayoutToCKTile<SIGNATURE.output.config.layout>());
using DsLayout = decltype(GetAuxiliaryTileTensorLayouts<SIGNATURE>())::type;
};
} // namespace ck_tile::builder::factory::internal

View File

@@ -1,118 +0,0 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/types.hpp"
#include <vector>
namespace ck_tile::builder::factory::internal {
// Validation helper: Ensure reference implementation only receives PassThrough elementwise ops
template <auto SIGNATURE>
consteval void ValidateReferenceSignature()
{
using namespace ck_tile::builder;
// Check input elementwise operation
static_assert(
!HasTensorOp<decltype(SIGNATURE.input)> ||
SIGNATURE.input.operation.elementwise_operation == ElementwiseOperation::PASS_THROUGH,
"Reference implementation does not support elementwise operations on input tensor. "
"Input operation must be PassThrough (or not specified).");
// Check weight elementwise operation
static_assert(
!HasTensorOp<decltype(SIGNATURE.weight)> ||
SIGNATURE.weight.operation.elementwise_operation == ElementwiseOperation::PASS_THROUGH,
"Reference implementation does not support elementwise operations on weight tensor. "
"Weight operation must be PassThrough (or not specified).");
// Check output elementwise operation
static_assert(
!HasTensorOp<decltype(SIGNATURE.output)> ||
SIGNATURE.output.operation.elementwise_operation == ElementwiseOperation::PASS_THROUGH,
"Reference implementation does not support elementwise operations on output tensor. "
"Output operation must be PassThrough (or not specified).");
}
// Common argument structure for reference convolution implementations
// Template parameters allow different const qualifiers for each direction
template <typename InPtrType, typename WeiPtrType, typename OutPtrType>
struct ReferenceConvArgument
{
InPtrType input_;
WeiPtrType weight_;
OutPtrType output_;
int G_, N_, K_, C_;
std::vector<ck_tile::long_index_t> input_spatial_;
std::vector<ck_tile::long_index_t> filter_spatial_;
std::vector<ck_tile::long_index_t> output_spatial_;
std::vector<ck_tile::long_index_t> strides_;
std::vector<ck_tile::long_index_t> dilations_;
std::vector<ck_tile::long_index_t> left_pads_;
ReferenceConvArgument(InPtrType input,
WeiPtrType weight,
OutPtrType output,
int G,
int N,
int K,
int C,
const std::vector<ck_tile::long_index_t>& input_spatial,
const std::vector<ck_tile::long_index_t>& filter_spatial,
const std::vector<ck_tile::long_index_t>& output_spatial,
const std::vector<ck_tile::long_index_t>& strides,
const std::vector<ck_tile::long_index_t>& dilations,
const std::vector<ck_tile::long_index_t>& left_pads)
: input_(input),
weight_(weight),
output_(output),
G_(G),
N_(N),
K_(K),
C_(C),
input_spatial_(input_spatial),
filter_spatial_(filter_spatial),
output_spatial_(output_spatial),
strides_(strides),
dilations_(dilations),
left_pads_(left_pads)
{
}
};
// Common invoker structure for reference convolution implementations
// Takes a callable (lambda or function pointer) to execute the actual convolution
template <typename ArgumentType, typename ConvFunc>
struct ReferenceConvInvoker
{
ConvFunc conv_func_;
explicit ReferenceConvInvoker(ConvFunc func) : conv_func_(func) {}
float Run(const ArgumentType* arg, const StreamConfig& stream_config = StreamConfig{})
{
(void)stream_config; // Unused for reference implementation
conv_func_(arg->input_,
arg->weight_,
arg->output_,
arg->G_,
arg->N_,
arg->K_,
arg->C_,
arg->input_spatial_,
arg->filter_spatial_,
arg->output_spatial_,
arg->strides_,
arg->dilations_,
arg->left_pads_);
return 0.0f; // Reference implementation doesn't track timing
}
};
} // namespace ck_tile::builder::factory::internal

View File

@@ -3,15 +3,15 @@
#pragma once
#include "ck_tile/ref/naive_grouped_conv_fwd_gpu.hpp"
#include "ck_tile/ref/naive_grouped_conv_bwd_data_gpu.hpp"
#include "ck_tile/ref/naive_grouped_conv_bwd_weight_gpu.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/types.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/reference_common.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck/library/reference_tensor_operation/gpu/naive_conv_fwd_gpu.hpp"
#include "ck/library/reference_tensor_operation/gpu/naive_conv_bwd_weight_gpu.hpp"
#include "ck/library/reference_tensor_operation/gpu/naive_conv_bwd_data_gpu.hpp"
#include "ck/library/utility/convolution_parameter.hpp"
#include <memory>
namespace ck_tile::builder::factory {
@@ -22,16 +22,23 @@ template <ConvSignatureDescriptor auto SIGNATURE,
StringLiteral VERSION>
struct ReferenceFactory
{
// Validate that only PassThrough elementwise operations are specified
static constexpr auto kValidation = (internal::ValidateReferenceSignature<SIGNATURE>(), 0);
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using Types = internal::ConvTensorDataTypes<SIGNATURE>;
using InDataType = typename Types::InDataType;
using WeiDataType = typename Types::WeiDataType;
using OutDataType = typename Types::OutDataType;
using Layouts = factory::internal::ConvTensorLayouts<SIGNATURE>;
using InLayout = typename Layouts::InLayout;
using WeiLayout = typename Layouts::WeiLayout;
using OutLayout = typename Layouts::OutLayout;
using Ops = factory::internal::ConvElementwiseOps<SIGNATURE>;
using InElementwiseOp = typename Ops::InElementwiseOp;
using WeiElementwiseOp = typename Ops::WeiElementwiseOp;
using OutElementwiseOp = typename Ops::OutElementwiseOp;
struct Instance
{
// Store template parameters for InstanceTraits reflection
@@ -39,91 +46,57 @@ struct ReferenceFactory
static constexpr auto kAlgorithm = ALGORITHM;
static constexpr auto kVersion = VERSION;
// Argument and Invoker types depend on direction
// Forward: const input, const weight, mutable output
// Backward Data: mutable input, const weight, const output_grad
// Backward Weight: const input, mutable weight_grad, const output_grad
// Use appropriate Argument type based on direction
using Argument = std::conditional_t<
ConvDirectionIsForward<SIGNATURE>,
internal::ReferenceConvArgument<const InDataType*, const WeiDataType*, OutDataType*>,
std::conditional_t<
ConvDirectionIsBackwardData<SIGNATURE>,
internal::
ReferenceConvArgument<InDataType*, const WeiDataType*, const OutDataType*>,
internal::
ReferenceConvArgument<const InDataType*, WeiDataType*, const OutDataType*>>>;
// Invoker calls the appropriate reference implementation based on direction
struct Invoker
/// @brief Invoke reference convolution
///
/// This is the primary overload to invoke reference convolution. As the underlying
/// function requires it, this function accepts ConvParam directly.
template <typename InPtrType, typename WeiPtrType, typename OutPtrType>
static void Run(InPtrType* input,
WeiPtrType* weight,
OutPtrType* output,
const ck::utils::conv::ConvParam& param,
InElementwiseOp in_op = InElementwiseOp{},
WeiElementwiseOp wei_op = WeiElementwiseOp{},
OutElementwiseOp out_op = OutElementwiseOp{})
{
float Run(const Argument* arg, const StreamConfig& stream_config = StreamConfig{})
if constexpr(ConvDirectionIsForward<SIGNATURE>)
{
(void)stream_config; // Unused for reference implementation
if constexpr(ConvDirectionIsForward<SIGNATURE>)
{
ck_tile::
naive_grouped_conv_fwd<SPATIAL_DIM, InDataType, WeiDataType, OutDataType>(
arg->input_,
arg->weight_,
arg->output_,
arg->G_,
arg->N_,
arg->K_,
arg->C_,
arg->input_spatial_,
arg->filter_spatial_,
arg->output_spatial_,
arg->strides_,
arg->dilations_,
arg->left_pads_);
}
else if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
{
ck_tile::naive_grouped_conv_bwd_data<SPATIAL_DIM,
InDataType,
WeiDataType,
OutDataType>(arg->input_,
arg->weight_,
arg->output_,
arg->G_,
arg->N_,
arg->K_,
arg->C_,
arg->input_spatial_,
arg->filter_spatial_,
arg->output_spatial_,
arg->strides_,
arg->dilations_,
arg->left_pads_);
}
else if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
{
ck_tile::naive_grouped_conv_bwd_weight<SPATIAL_DIM,
InDataType,
WeiDataType,
OutDataType>(arg->input_,
arg->weight_,
arg->output_,
arg->G_,
arg->N_,
arg->K_,
arg->C_,
arg->input_spatial_,
arg->filter_spatial_,
arg->output_spatial_,
arg->strides_,
arg->dilations_,
arg->left_pads_);
}
return 0.0f; // Reference implementation doesn't track timing
ck::ref::naive_conv_fwd<InLayout, WeiLayout, OutLayout>(
static_cast<const InDataType*>(input),
static_cast<const WeiDataType*>(weight),
static_cast<OutDataType*>(output),
param,
in_op,
wei_op,
out_op);
}
};
else if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
{
ck::ref::naive_conv_bwd_data<InLayout, WeiLayout, OutLayout>(
static_cast<InDataType*>(input),
static_cast<const WeiDataType*>(weight),
static_cast<const OutDataType*>(output),
param,
in_op,
wei_op,
out_op);
}
else if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
{
ck::ref::naive_conv_bwd_weight<InLayout, WeiLayout, OutLayout>(
static_cast<const InDataType*>(input),
static_cast<WeiDataType*>(weight),
static_cast<const OutDataType*>(output),
param,
in_op,
wei_op,
out_op);
}
}
// Direct Run method (simpler interface, direction-agnostic)
/// @brief Invoke reference convolution
///
/// Convenience overload to avoid having to construct ConvParam manually.
template <typename InPtrType, typename WeiPtrType, typename OutPtrType>
static void Run(InPtrType* input,
WeiPtrType* weight,
@@ -132,68 +105,27 @@ struct ReferenceFactory
int N,
int K,
int C,
const std::vector<ck_tile::long_index_t>& input_spatial,
const std::vector<ck_tile::long_index_t>& filter_spatial,
const std::vector<ck_tile::long_index_t>& output_spatial,
const std::vector<ck_tile::long_index_t>& strides,
const std::vector<ck_tile::long_index_t>& dilations,
const std::vector<ck_tile::long_index_t>& left_pads)
const std::vector<ck::long_index_t>& input_spatial,
const std::vector<ck::long_index_t>& filter_spatial,
const std::vector<ck::long_index_t>& strides,
const std::vector<ck::long_index_t>& dilations,
const std::vector<ck::long_index_t>& left_pads,
const std::vector<ck::long_index_t>& right_pads)
{
if constexpr(ConvDirectionIsForward<SIGNATURE>)
{
ck_tile::naive_grouped_conv_fwd<SPATIAL_DIM, InDataType, WeiDataType, OutDataType>(
static_cast<const InDataType*>(input),
static_cast<const WeiDataType*>(weight),
static_cast<OutDataType*>(output),
G,
N,
K,
C,
input_spatial,
filter_spatial,
output_spatial,
strides,
dilations,
left_pads);
}
else if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
{
ck_tile::
naive_grouped_conv_bwd_data<SPATIAL_DIM, InDataType, WeiDataType, OutDataType>(
static_cast<InDataType*>(input),
static_cast<const WeiDataType*>(weight),
static_cast<const OutDataType*>(output),
G,
N,
K,
C,
input_spatial,
filter_spatial,
output_spatial,
strides,
dilations,
left_pads);
}
else if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
{
ck_tile::naive_grouped_conv_bwd_weight<SPATIAL_DIM,
InDataType,
WeiDataType,
OutDataType>(
static_cast<const InDataType*>(input),
static_cast<WeiDataType*>(weight),
static_cast<const OutDataType*>(output),
G,
N,
K,
C,
input_spatial,
filter_spatial,
output_spatial,
strides,
dilations,
left_pads);
}
Run(input,
weight,
output,
ck::utils::conv::ConvParam(SPATIAL_DIM,
G,
N,
K,
C,
filter_spatial,
input_spatial,
strides,
dilations,
left_pads,
right_pads));
}
std::string GetTypeString() const
@@ -209,41 +141,6 @@ struct ReferenceFactory
return std::string("GPU_Reference_") + dir_str + "_" + std::to_string(SPATIAL_DIM) +
"D";
}
// Old CK interface: Create argument pointer
template <typename InPtrType, typename WeiPtrType, typename OutPtrType>
std::unique_ptr<Argument>
MakeArgumentPointer(InPtrType input,
WeiPtrType weight,
OutPtrType output,
int G,
int N,
int K,
int C,
const std::vector<ck_tile::long_index_t>& input_spatial,
const std::vector<ck_tile::long_index_t>& filter_spatial,
const std::vector<ck_tile::long_index_t>& output_spatial,
const std::vector<ck_tile::long_index_t>& strides,
const std::vector<ck_tile::long_index_t>& dilations,
const std::vector<ck_tile::long_index_t>& left_pads) const
{
return std::make_unique<Argument>(input,
weight,
output,
G,
N,
K,
C,
input_spatial,
filter_spatial,
output_spatial,
strides,
dilations,
left_pads);
}
// Old CK interface: Create invoker pointer
std::unique_ptr<Invoker> MakeInvokerPointer() const { return std::make_unique<Invoker>(); }
};
};

View File

@@ -76,7 +76,7 @@ struct Args<SIGNATURE>
using Ops = factory::internal::ConvElementwiseOps<SIGNATURE>;
// TODO: We shouldn't need to call into an internal namespace here.
using Layouts = factory::internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = factory::internal::ConvTensorLayouts<SIGNATURE>;
ConvTensorLengths<SPATIAL_DIM> lengths;

View File

@@ -32,27 +32,8 @@ concept RefConvInstance = requires(Conv& conv,
const void* input,
const void* weight,
void* output,
int G,
int N,
int K,
int C,
std::vector<long_index_t> dims) {
{
conv.Run(input,
weight,
output,
G,
N,
K,
C,
dims, // input_spatial
dims, // filter_spatial
dims, // output_spatial
dims, // strides
dims, // dilations
dims // left_pads
)
};
ck::utils::conv::ConvParam param) {
{ conv.Run(input, weight, output, param) };
};
/// @brief `run()` specialization for forward convolution and the reference
@@ -84,16 +65,6 @@ std::tuple<bool, float> run(RefConvInstance<SIGNATURE> auto& conv,
// Just throw for now, but regard these as TODO items that should be resolved
// eventually.
// Right pads are not supported right now for some reason.
for(auto right_pad : param.input_right_pads_)
{
if(right_pad != 0)
{
std::cout << "TODO: Support right pad in reference conv" << std::endl;
return std::make_tuple(false, 0.0f);
}
}
if(!args.make_input_descriptor().is_packed())
{
std::cout << "TODO: Support non-packed input tensor in reference conv" << std::endl;
@@ -110,19 +81,7 @@ std::tuple<bool, float> run(RefConvInstance<SIGNATURE> auto& conv,
return std::make_tuple(false, 0.0f);
}
conv.Run(inputs.input,
inputs.weight,
outputs.output,
param.G_,
param.N_,
param.K_,
param.C_,
param.input_spatial_lengths_,
param.filter_spatial_lengths_,
param.output_spatial_lengths_,
param.conv_filter_strides_,
param.conv_filter_dilations_,
param.input_left_pads_);
conv.Run(inputs.input, inputs.weight, outputs.output, param);
return std::make_tuple(true, 0.0f);
}