mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 06:49:15 +00:00
[CK_BUILDER] Add GPU Reference Algorithm to CK Builder (#3381)
* [CK_BUILDER] Integrate GPU reference as ConvAlgorithm

  Add GPU reference as a ConvAlgorithm specialization, enabling:
  - Unified Builder API for reference and optimized kernels
  - Future ckProfiler integration for validation
  - First step toward numerical validation in Builder tests

  Changes:
  - Add ConvAlgorithmSpecialization::REFERENCE enum
  - Add ConvAlgorithm_Reference struct
  - Add IsReferenceAlgorithm concept
  - Create 3 reference factories (Forward, BwdData, BwdWeight)
  - Wire into conv_dispatcher
  - Add proof-of-concept test (passing)

  Test result: Can instantiate reference through Builder API

* Add GPU reference execution tests
  - Reference kernel executes through Builder (459ms)
  - Both reference and optimized can instantiate
  - Tests passing
  Next: Implement utilities for comparison

* Optimized Builder kernel execution works
  - MakeArgument pattern implemented
  - Builder-generated kernel executes successfully
  - Tests passing (451ms execution)
  Next: Add comparison

* VALIDATION COMPLETE: Builder == Reference

  Builder-generated kernel output matches GPU reference!
  Test: Validate_Optimized_vs_Reference_Forward_2D_FP16
  Result: PASS ✓
  This proves CK Builder generates correct code!

* Update to new Builder API
  All tests passing

* Rename test file for clarity
  test_builder_kernel_execution -> test_builder_kernel_validation

* Add all 3 directions support
  - Forward, Backward Data, Backward Weight
  - All reference factories working
  - Dispatcher wired for all directions
  - 9 tests passing
  Tests:
  - test_reference_execution: 3 tests (all directions)
  - test_optimized_execution: 3 tests (all directions)
  - test_builder_kernel_validation: 3 tests (fwd validated, bwd placeholders)

* Add backward direction support
  - Backward data and weight dispatcher wiring
  - Fix factories for new API
  - All 3 directions tested
  - 9 tests passing

* Refactor: Change IsReferenceAlgorithm from concept to consteval function

  Address review feedback: Use a consteval function in the dispatcher instead of
  a concept, matching the pattern for other algorithms (Tile, XDL, WMMA, DL).
  - Remove IsReferenceAlgorithm concept from conv_algorithm_concepts.hpp
  - Add IsReferenceAlgorithm() consteval function to conv_dispatcher.hpp
  - Update dispatcher to use function call: IsReferenceAlgorithm<T>()
  - Remove redundant algorithm checks from reference factory requires clauses
  All tests passing (9/9).

* Move Tile algorithm check outside direction block to support all directions

* Implement MakeInvokerPointer interface and add random input validation
  - Implement full Argument/Invoker structs for old CK interface (not just nullptr)
  - Refactor with reference_common.hpp to reduce code duplication
  - Add random input validation tests: Builder vs direct GPU reference (all directions)
  - Fix layout: GNHWC -> NHWGC to match reference kernel expectations
  - All 12 tests pass with IDENTICAL results on random input

* Move ConvAlgorithm_Reference to test/impl/conv_algorithm_types.hpp
  Keep types.hpp for data types only (enums); move algorithm descriptors to
  conv_algorithm_types.hpp as suggested by review.

* Add static_assert to ensure reference factories only accept PassThrough operations
  The reference implementation doesn't support fused elementwise operations. Add
  compile-time validation to fail early with a clear error message if
  non-PassThrough operations are specified on input, weight, or output.

* Add InstanceTraits support for reference kernels
  - Store SIGNATURE/ALGORITHM/VERSION in Instance for reflection
  - Create shared ReferenceCommonTraits base for common properties
  - Add 3 direction-specific InstanceTraits specializations in one file
  - Include data type and layouts in instance_string output

* Remove optimized kernel validation tests from reference-only branch

* Use existing layout helper and organize reference tests
  Use LayoutToCK from conv_tensor_layout.hpp and move the reference
  InstanceTraits test to the validation folder.

* Merge develop branch
  Fix DataType switch for new mixed precision types.

* Fix comment spacing for CI

* Convert IsReferenceAlgorithm from function to concept

* Add reference tests to CI smoke tests

* Consolidate 3 reference factories into single unified factory

---------

Co-authored-by: Ville Pietilä <188998872+vpietila-amd@users.noreply.github.com>
@@ -9,10 +9,11 @@
 // ## Design Overview
 //
 // The dispatcher operates in two phases:
-// 1. **Algorithm Identification**: Five `consteval` predicate functions (`IsXdlV3Algorithm`,
-//    `IsXdlAlgorithm`, `IsWmmaAlgorithm`, `IsDlAlgorithm`, `IsLargeTensorAlgorithm`) inspect
-//    the algorithm descriptor's structure to determine which kernel variant it satisfies.
-//    Each predicate checks a specific set of concept constraints that define a kernel variant.
+// 1. **Algorithm Identification**: Six `consteval` predicate functions (`IsReferenceAlgorithm`,
+//    `IsXdlV3Algorithm`, `IsXdlAlgorithm`, `IsWmmaAlgorithm`, `IsDlAlgorithm`,
+//    `IsLargeTensorAlgorithm`) inspect the algorithm descriptor's structure to determine which
+//    kernel variant it satisfies. Each predicate checks a specific set of concept constraints
+//    that define a kernel variant.
 //
 // 2. **Factory Routing**: The main `make_conv_instance()` function uses `if constexpr`
 //    to dispatch to the appropriate factory class based on both the convolution direction
@@ -21,6 +22,9 @@
 //
 // ## Supported Kernel Variants
 //
+// - **Reference**: Simple reference implementation for validation. Only requires a specialization
+//   field set to ConvAlgorithmSpecialization::REFERENCE.
+//
 // - **XDL V3**: Newer XDL-based pipeline using block GEMM structure. Requires fewer parameters
 //   than standard XDL (e.g., uses `SpecifiesBlockGemm` instead of scheduling/prefetch configs).
 //
@@ -59,6 +63,7 @@
 #include "ck_tile/builder/factory/conv_fwd_wmma_factory.hpp"
 #include "ck_tile/builder/factory/conv_fwd_dl_factory.hpp"
 #include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp"
+#include "ck_tile/builder/factory/reference_factory.hpp"
 #include "ck_tile/builder/factory/conv_tile_factory.hpp"
 
 namespace ck_tile::builder::factory {
@@ -82,6 +87,13 @@ namespace ck_tile::builder::factory {
 //
 // TODO: Make this dispatch logic much more robust and clear for users.
 
+// Reference algorithm (simplest implementation for validation)
+template <typename T>
+concept IsReferenceAlgorithm = ConvAlgorithmDescriptor<T> && requires {
+    { T::specialization } -> std::convertible_to<ConvAlgorithmSpecialization>;
+    requires T::specialization == ConvAlgorithmSpecialization::REFERENCE;
+};
+
 // CK Tile kernel
 template <typename T>
 concept IsTileAlgorithm = ConvAlgorithmDescriptor<T> && SpecifiesTileThreadBlock<T> &&
@@ -132,11 +144,17 @@ constexpr auto make_conv_instance()
 {
     using AlgoType = std::remove_const_t<decltype(ALGORITHM)>;
 
+    // Reference algorithm supports all directions
+    if constexpr(IsReferenceAlgorithm<AlgoType>)
+    {
+        return typename ReferenceFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
+    }
     // CK Tile supports common factory for each direction
-    if constexpr(IsTileAlgorithm<AlgoType>)
+    else if constexpr(IsTileAlgorithm<AlgoType>)
     {
         return typename ConvTileFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
     }
     // Forward direction (supports most algorithm variants)
     else if constexpr(ConvDirectionIsForward<SIGNATURE>)
     {
         if constexpr(IsXdlV3Algorithm<AlgoType>)
@@ -164,23 +182,25 @@ constexpr auto make_conv_instance()
             static_assert(
                 false,
                 "No suitable forward convolution kernel factory found for the provided ALGORITHM. "
-                "The ALGORITHM must satisfy requirements for one of: XDL V3, XDL, WMMA, DL (NHWC "
-                "layout), or Large Tensor variant.");
+                "The ALGORITHM must satisfy requirements for one of: Reference, Tile, XDL V3, XDL, "
+                "WMMA, DL (NHWC layout), or Large Tensor variant.");
         }
     }
     // Backward data direction (will expand with more algorithms in the future)
     else if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
     {
-        static_assert(
-            false,
-            "Backward data convolution is not yet supported. "
-            "Only forward convolution (ConvDirection::FORWARD) is currently implemented.");
+        static_assert(false,
+                      "Backward data convolution: Only reference and tile algorithms supported "
+                      "currently. "
+                      "Optimized kernels (XDL, WMMA, etc.) not yet implemented.");
     }
    // Backward weight direction (will expand with more algorithms in the future)
     else if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
     {
-        static_assert(
-            false,
-            "Backward weight convolution is not yet supported. "
-            "Only forward convolution (ConvDirection::FORWARD) is currently implemented.");
+        static_assert(false,
+                      "Backward weight convolution: Only reference and tile algorithms "
+                      "supported currently. "
+                      "Optimized kernels (XDL, WMMA, etc.) not yet implemented.");
     }
     else
     {
@@ -0,0 +1,118 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/builder/conv_signature_concepts.hpp"
+#include "ck_tile/builder/types.hpp"
+#include <vector>
+
+namespace ck_tile::builder::factory::internal {
+
+// Validation helper: Ensure reference implementation only receives PassThrough elementwise ops
+template <auto SIGNATURE>
+consteval void ValidateReferenceSignature()
+{
+    using namespace ck_tile::builder;
+
+    // Check input elementwise operation
+    static_assert(
+        !HasTensorOp<decltype(SIGNATURE.input)> ||
+            SIGNATURE.input.operation.elementwise_operation == ElementwiseOperation::PASS_THROUGH,
+        "Reference implementation does not support elementwise operations on input tensor. "
+        "Input operation must be PassThrough (or not specified).");
+
+    // Check weight elementwise operation
+    static_assert(
+        !HasTensorOp<decltype(SIGNATURE.weight)> ||
+            SIGNATURE.weight.operation.elementwise_operation == ElementwiseOperation::PASS_THROUGH,
+        "Reference implementation does not support elementwise operations on weight tensor. "
+        "Weight operation must be PassThrough (or not specified).");
+
+    // Check output elementwise operation
+    static_assert(
+        !HasTensorOp<decltype(SIGNATURE.output)> ||
+            SIGNATURE.output.operation.elementwise_operation == ElementwiseOperation::PASS_THROUGH,
+        "Reference implementation does not support elementwise operations on output tensor. "
+        "Output operation must be PassThrough (or not specified).");
+}
+
+// Common argument structure for reference convolution implementations
+// Template parameters allow different const qualifiers for each direction
+template <typename InPtrType, typename WeiPtrType, typename OutPtrType>
+struct ReferenceConvArgument
+{
+    InPtrType input_;
+    WeiPtrType weight_;
+    OutPtrType output_;
+    int G_, N_, K_, C_;
+    std::vector<ck_tile::long_index_t> input_spatial_;
+    std::vector<ck_tile::long_index_t> filter_spatial_;
+    std::vector<ck_tile::long_index_t> output_spatial_;
+    std::vector<ck_tile::long_index_t> strides_;
+    std::vector<ck_tile::long_index_t> dilations_;
+    std::vector<ck_tile::long_index_t> left_pads_;
+
+    ReferenceConvArgument(InPtrType input,
+                          WeiPtrType weight,
+                          OutPtrType output,
+                          int G,
+                          int N,
+                          int K,
+                          int C,
+                          const std::vector<ck_tile::long_index_t>& input_spatial,
+                          const std::vector<ck_tile::long_index_t>& filter_spatial,
+                          const std::vector<ck_tile::long_index_t>& output_spatial,
+                          const std::vector<ck_tile::long_index_t>& strides,
+                          const std::vector<ck_tile::long_index_t>& dilations,
+                          const std::vector<ck_tile::long_index_t>& left_pads)
+        : input_(input),
+          weight_(weight),
+          output_(output),
+          G_(G),
+          N_(N),
+          K_(K),
+          C_(C),
+          input_spatial_(input_spatial),
+          filter_spatial_(filter_spatial),
+          output_spatial_(output_spatial),
+          strides_(strides),
+          dilations_(dilations),
+          left_pads_(left_pads)
+    {
+    }
+};
+
+// Common invoker structure for reference convolution implementations
+// Takes a callable (lambda or function pointer) to execute the actual convolution
+template <typename ArgumentType, typename ConvFunc>
+struct ReferenceConvInvoker
+{
+    ConvFunc conv_func_;
+
+    explicit ReferenceConvInvoker(ConvFunc func) : conv_func_(func) {}
+
+    float Run(const ArgumentType* arg, const StreamConfig& stream_config = StreamConfig{})
+    {
+        (void)stream_config; // Unused for reference implementation
+
+        conv_func_(arg->input_,
+                   arg->weight_,
+                   arg->output_,
+                   arg->G_,
+                   arg->N_,
+                   arg->K_,
+                   arg->C_,
+                   arg->input_spatial_,
+                   arg->filter_spatial_,
+                   arg->output_spatial_,
+                   arg->strides_,
+                   arg->dilations_,
+                   arg->left_pads_);
+
+        return 0.0f; // Reference implementation doesn't track timing
+    }
+};
+
+} // namespace ck_tile::builder::factory::internal
@@ -0,0 +1,249 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include "ck_tile/ref/naive_grouped_conv_fwd_gpu.hpp"
+#include "ck_tile/ref/naive_grouped_conv_bwd_data_gpu.hpp"
+#include "ck_tile/ref/naive_grouped_conv_bwd_weight_gpu.hpp"
+#include "ck_tile/builder/conv_signature_concepts.hpp"
+#include "ck_tile/builder/conv_algorithm_concepts.hpp"
+#include "ck_tile/builder/types.hpp"
+#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
+#include "ck_tile/builder/factory/reference_common.hpp"
+#include "ck_tile/core.hpp"
+#include <memory>
+
+namespace ck_tile::builder::factory {
+
+// Unified Factory for GPU Reference Convolution (all directions)
+template <ConvSignatureDescriptor auto SIGNATURE,
+          ConvAlgorithmDescriptor auto ALGORITHM,
+          StringLiteral VERSION>
+struct ReferenceFactory
+{
+    // Validate that only PassThrough elementwise operations are specified
+    static constexpr auto kValidation = (internal::ValidateReferenceSignature<SIGNATURE>(), 0);
+
+    static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
+    using Types = internal::FwdConvTensorDataTypes<SIGNATURE>;
+
+    using InDataType  = typename Types::ADataType;
+    using WeiDataType = typename Types::BDataType;
+    using OutDataType = typename Types::EDataType;
+
+    struct Instance
+    {
+        // Store template parameters for InstanceTraits reflection
+        static constexpr auto kSignature = SIGNATURE;
+        static constexpr auto kAlgorithm = ALGORITHM;
+        static constexpr auto kVersion   = VERSION;
+
+        // Argument and Invoker types depend on direction
+        // Forward:         const input, const weight, mutable output
+        // Backward Data:   mutable input, const weight, const output_grad
+        // Backward Weight: const input, mutable weight_grad, const output_grad
+
+        // Use appropriate Argument type based on direction
+        using Argument = std::conditional_t<
+            ConvDirectionIsForward<SIGNATURE>,
+            internal::ReferenceConvArgument<const InDataType*, const WeiDataType*, OutDataType*>,
+            std::conditional_t<
+                ConvDirectionIsBackwardData<SIGNATURE>,
+                internal::ReferenceConvArgument<InDataType*, const WeiDataType*, const OutDataType*>,
+                internal::ReferenceConvArgument<const InDataType*, WeiDataType*, const OutDataType*>>>;
+
+        // Invoker calls the appropriate reference implementation based on direction
+        struct Invoker
+        {
+            float Run(const Argument* arg, const StreamConfig& stream_config = StreamConfig{})
+            {
+                (void)stream_config; // Unused for reference implementation
+
+                if constexpr(ConvDirectionIsForward<SIGNATURE>)
+                {
+                    ck_tile::naive_grouped_conv_fwd<SPATIAL_DIM, InDataType, WeiDataType, OutDataType>(
+                        arg->input_,
+                        arg->weight_,
+                        arg->output_,
+                        arg->G_,
+                        arg->N_,
+                        arg->K_,
+                        arg->C_,
+                        arg->input_spatial_,
+                        arg->filter_spatial_,
+                        arg->output_spatial_,
+                        arg->strides_,
+                        arg->dilations_,
+                        arg->left_pads_);
+                }
+                else if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
+                {
+                    ck_tile::naive_grouped_conv_bwd_data<SPATIAL_DIM, InDataType, WeiDataType, OutDataType>(
+                        arg->input_,
+                        arg->weight_,
+                        arg->output_,
+                        arg->G_,
+                        arg->N_,
+                        arg->K_,
+                        arg->C_,
+                        arg->input_spatial_,
+                        arg->filter_spatial_,
+                        arg->output_spatial_,
+                        arg->strides_,
+                        arg->dilations_,
+                        arg->left_pads_);
+                }
+                else if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
+                {
+                    ck_tile::naive_grouped_conv_bwd_weight<SPATIAL_DIM, InDataType, WeiDataType, OutDataType>(
+                        arg->input_,
+                        arg->weight_,
+                        arg->output_,
+                        arg->G_,
+                        arg->N_,
+                        arg->K_,
+                        arg->C_,
+                        arg->input_spatial_,
+                        arg->filter_spatial_,
+                        arg->output_spatial_,
+                        arg->strides_,
+                        arg->dilations_,
+                        arg->left_pads_);
+                }
+
+                return 0.0f; // Reference implementation doesn't track timing
+            }
+        };
+
+        // Direct Run method (simpler interface, direction-agnostic)
+        template <typename InPtrType, typename WeiPtrType, typename OutPtrType>
+        static void Run(InPtrType input,
+                        WeiPtrType weight,
+                        OutPtrType output,
+                        int G,
+                        int N,
+                        int K,
+                        int C,
+                        const std::vector<ck_tile::long_index_t>& input_spatial,
+                        const std::vector<ck_tile::long_index_t>& filter_spatial,
+                        const std::vector<ck_tile::long_index_t>& output_spatial,
+                        const std::vector<ck_tile::long_index_t>& strides,
+                        const std::vector<ck_tile::long_index_t>& dilations,
+                        const std::vector<ck_tile::long_index_t>& left_pads)
+        {
+            if constexpr(ConvDirectionIsForward<SIGNATURE>)
+            {
+                ck_tile::naive_grouped_conv_fwd<SPATIAL_DIM, InDataType, WeiDataType, OutDataType>(
+                    input,
+                    weight,
+                    output,
+                    G,
+                    N,
+                    K,
+                    C,
+                    input_spatial,
+                    filter_spatial,
+                    output_spatial,
+                    strides,
+                    dilations,
+                    left_pads);
+            }
+            else if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
+            {
+                ck_tile::naive_grouped_conv_bwd_data<SPATIAL_DIM, InDataType, WeiDataType, OutDataType>(
+                    input,
+                    weight,
+                    output,
+                    G,
+                    N,
+                    K,
+                    C,
+                    input_spatial,
+                    filter_spatial,
+                    output_spatial,
+                    strides,
+                    dilations,
+                    left_pads);
+            }
+            else if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
+            {
+                ck_tile::naive_grouped_conv_bwd_weight<SPATIAL_DIM, InDataType, WeiDataType, OutDataType>(
+                    input,
+                    weight,
+                    output,
+                    G,
+                    N,
+                    K,
+                    C,
+                    input_spatial,
+                    filter_spatial,
+                    output_spatial,
+                    strides,
+                    dilations,
+                    left_pads);
+            }
+        }
+
+        std::string GetTypeString() const
+        {
+            std::string dir_str;
+            if constexpr(ConvDirectionIsForward<SIGNATURE>)
+                dir_str = "Forward";
+            else if constexpr(ConvDirectionIsBackwardData<SIGNATURE>)
+                dir_str = "BackwardData";
+            else if constexpr(ConvDirectionIsBackwardWeight<SIGNATURE>)
+                dir_str = "BackwardWeight";
+
+            return std::string("GPU_Reference_") + dir_str + "_" + std::to_string(SPATIAL_DIM) +
+                   "D";
+        }
+
+        // Old CK interface: Create argument pointer
+        template <typename InPtrType, typename WeiPtrType, typename OutPtrType>
+        std::unique_ptr<Argument>
+        MakeArgumentPointer(InPtrType input,
+                            WeiPtrType weight,
+                            OutPtrType output,
+                            int G,
+                            int N,
+                            int K,
+                            int C,
+                            const std::vector<ck_tile::long_index_t>& input_spatial,
+                            const std::vector<ck_tile::long_index_t>& filter_spatial,
+                            const std::vector<ck_tile::long_index_t>& output_spatial,
+                            const std::vector<ck_tile::long_index_t>& strides,
+                            const std::vector<ck_tile::long_index_t>& dilations,
+                            const std::vector<ck_tile::long_index_t>& left_pads) const
+        {
+            return std::make_unique<Argument>(input,
+                                              weight,
+                                              output,
+                                              G,
+                                              N,
+                                              K,
+                                              C,
+                                              input_spatial,
+                                              filter_spatial,
+                                              output_spatial,
+                                              strides,
+                                              dilations,
+                                              left_pads);
+        }
+
+        // Old CK interface: Create invoker pointer
+        std::unique_ptr<Invoker> MakeInvokerPointer() const { return std::make_unique<Invoker>(); }
+    };
+};
+
+} // namespace ck_tile::builder::factory
@@ -0,0 +1,191 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+// InstanceTraits specializations for Reference convolution kernels
+//
+// This file provides compile-time reflection for all three reference kernel directions
+// (Forward, Backward Data, Backward Weight) using a shared base to reduce duplication.
+
+#pragma once
+
+#include "instance_traits.hpp"
+#include "instance_traits_util.hpp"
+#include "ck_tile/builder/factory/reference_factory.hpp"
+#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
+#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
+#include <sstream>
+
+namespace ck_tile::reflect {
+
+namespace internal {
+
+// Common traits shared by all reference implementations
+template <auto SIGNATURE>
+struct ReferenceCommonTraits
+{
+    // Spatial dimension
+    static constexpr int kSpatialDim = SIGNATURE.spatial_dim;
+
+    // Layouts - map from enum to type using existing helper
+    using InLayout =
+        typename builder::factory::internal::LayoutToCK<SIGNATURE.input.config.layout>::type;
+    using WeiLayout =
+        typename builder::factory::internal::LayoutToCK<SIGNATURE.weight.config.layout>::type;
+    using OutLayout =
+        typename builder::factory::internal::LayoutToCK<SIGNATURE.output.config.layout>::type;
+
+    // Data types - extract from factory's type helper
+    using Types = builder::factory::internal::FwdConvTensorDataTypes<SIGNATURE>;
+    using ADataType = typename Types::ADataType;
+    using BDataType = typename Types::BDataType;
+    using EDataType = typename Types::EDataType;
+    using AccDataType = float; // Reference uses float accumulation
+
+    // Elementwise operations - reference only supports PassThrough
+    using AElementwiseOperation = ck_tile::element_wise::PassThrough;
+    using BElementwiseOperation = ck_tile::element_wise::PassThrough;
+    using CDEElementwiseOperation = ck_tile::element_wise::PassThrough;
+
+    // Reference has no block/tile configuration (simple kernel)
+    // These are set to 0 to indicate "not applicable"
+    static constexpr int kBlockSize = 0;
+    static constexpr int kMPerBlock = 0;
+    static constexpr int kNPerBlock = 0;
+    static constexpr int kKPerBlock = 0;
+};
+
+} // namespace internal
+
+// ============================================================================
+// InstanceTraits specialization for Reference Forward Convolution
+// ============================================================================
+template <typename Instance>
+    requires(
+        std::is_same_v<std::remove_const_t<decltype(Instance::kAlgorithm.specialization)>,
+                       builder::ConvAlgorithmSpecialization> &&
+        (Instance::kAlgorithm.specialization == builder::ConvAlgorithmSpecialization::REFERENCE) &&
+        builder::ConvDirectionIsForward<Instance::kSignature>)
+struct InstanceTraits<Instance> : internal::ReferenceCommonTraits<Instance::kSignature>
+{
+    using Base = internal::ReferenceCommonTraits<Instance::kSignature>;
+
+    // Bring base class members into scope
+    using Base::kBlockSize;
+    using Base::kKPerBlock;
+    using Base::kMPerBlock;
+    using Base::kNPerBlock;
+    using Base::kSpatialDim;
+    using typename Base::AccDataType;
+    using typename Base::ADataType;
+    using typename Base::AElementwiseOperation;
+    using typename Base::BDataType;
+    using typename Base::BElementwiseOperation;
+    using typename Base::CDEElementwiseOperation;
+    using typename Base::EDataType;
+    using typename Base::InLayout;
+    using typename Base::OutLayout;
+    using typename Base::WeiLayout;
+
+    static constexpr builder::ConvDirection direction = builder::ConvDirection::FORWARD;
+
+    static std::string instance_string()
+    {
+        std::ostringstream oss;
+        oss << "GPU_Reference_Forward_" << kSpatialDim << "D";
+        oss << "_" << detail::type_name<ADataType>();
+        oss << "_" << detail::layout_name<InLayout>();
+        oss << "_" << detail::layout_name<WeiLayout>();
+        oss << "_" << detail::layout_name<OutLayout>();
+        return oss.str();
+    }
+};
+
+// ============================================================================
+// InstanceTraits specialization for Reference Backward Data Convolution
+// ============================================================================
+template <typename Instance>
+    requires(
+        std::is_same_v<std::remove_const_t<decltype(Instance::kAlgorithm.specialization)>,
+                       builder::ConvAlgorithmSpecialization> &&
+        (Instance::kAlgorithm.specialization == builder::ConvAlgorithmSpecialization::REFERENCE) &&
+        builder::ConvDirectionIsBackwardData<Instance::kSignature>)
+struct InstanceTraits<Instance> : internal::ReferenceCommonTraits<Instance::kSignature>
+{
+    using Base = internal::ReferenceCommonTraits<Instance::kSignature>;
+
+    // Bring base class members into scope
+    using Base::kBlockSize;
+    using Base::kKPerBlock;
+    using Base::kMPerBlock;
+    using Base::kNPerBlock;
+    using Base::kSpatialDim;
+    using typename Base::AccDataType;
+    using typename Base::ADataType;
+    using typename Base::AElementwiseOperation;
+    using typename Base::BDataType;
+    using typename Base::BElementwiseOperation;
+    using typename Base::CDEElementwiseOperation;
+    using typename Base::EDataType;
+    using typename Base::InLayout;
+    using typename Base::OutLayout;
+    using typename Base::WeiLayout;
+
+    static constexpr builder::ConvDirection direction = builder::ConvDirection::BACKWARD_DATA;
+
+    static std::string instance_string()
+    {
+        std::ostringstream oss;
+        oss << "GPU_Reference_BackwardData_" << kSpatialDim << "D";
+        oss << "_" << detail::type_name<ADataType>();
+        oss << "_" << detail::layout_name<InLayout>();
+        oss << "_" << detail::layout_name<WeiLayout>();
+        oss << "_" << detail::layout_name<OutLayout>();
+        return oss.str();
+    }
+};
+
+// ============================================================================
+// InstanceTraits specialization for Reference Backward Weight Convolution
+// ============================================================================
+template <typename Instance>
+    requires(
+        std::is_same_v<std::remove_const_t<decltype(Instance::kAlgorithm.specialization)>,
+                       builder::ConvAlgorithmSpecialization> &&
+        (Instance::kAlgorithm.specialization == builder::ConvAlgorithmSpecialization::REFERENCE) &&
+        builder::ConvDirectionIsBackwardWeight<Instance::kSignature>)
+struct InstanceTraits<Instance> : internal::ReferenceCommonTraits<Instance::kSignature>
+{
+    using Base = internal::ReferenceCommonTraits<Instance::kSignature>;
+
+    // Bring base class members into scope
+    using Base::kBlockSize;
+    using Base::kKPerBlock;
+    using Base::kMPerBlock;
+    using Base::kNPerBlock;
+    using Base::kSpatialDim;
+    using typename Base::AccDataType;
+    using typename Base::ADataType;
+    using typename Base::AElementwiseOperation;
+    using typename Base::BDataType;
+    using typename Base::BElementwiseOperation;
+    using typename Base::CDEElementwiseOperation;
+    using typename Base::EDataType;
+    using typename Base::InLayout;
+    using typename Base::OutLayout;
+    using typename Base::WeiLayout;
+
+    static constexpr builder::ConvDirection direction = builder::ConvDirection::BACKWARD_WEIGHT;
+
+    static std::string instance_string()
+    {
+        std::ostringstream oss;
+        oss << "GPU_Reference_BackwardWeight_" << kSpatialDim << "D";
+        oss << "_" << detail::type_name<ADataType>();
+        oss << "_" << detail::layout_name<InLayout>();
+        oss << "_" << detail::layout_name<WeiLayout>();
+        oss << "_" << detail::layout_name<OutLayout>();
+        return oss.str();
+    }
+};
+
+} // namespace ck_tile::reflect
@@ -248,7 +248,8 @@ enum class PipelineScheduler
 
 enum class ConvAlgorithmSpecialization
 {
-    LARGE_TENSOR
+    LARGE_TENSOR,
+    REFERENCE // GPU reference implementation for validation
 };
 
 // toString methods for enum classes
 
@@ -84,21 +84,29 @@ add_ck_builder_test(test_ckb_conv_builder
     unit_conv_tensor_layout.cpp
     unit_conv_tensor_type.cpp
     unit_conv_thread_block.cpp
-    unit_conv_tuning_params.cpp
-    unit_conv_fwd_testing.cpp)
-target_link_libraries(test_ckb_conv_builder PRIVATE utility)
+    unit_conv_tuning_params.cpp)
 
 # Tests the inline diff utility used for comparing strings in tests assertions
 add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp)
 
+# GPU reference validation tests (in validation/ folder)
+# 1. Reference kernel execution and InstanceTraits
+add_ck_builder_test(test_ckb_reference_execution
+    validation/test_reference_execution.cpp
+    validation/test_reference_instance_traits.cpp)
+target_link_libraries(test_ckb_reference_execution PRIVATE utility)
+
+# Note: Optimized kernel validation tests will be added after merging dev branch
+# with kernel Run() implementation from colleague's work
+
 # Tests convolution trait selection and configuration
 add_ck_builder_test(test_ckb_conv_traits
     conv/ck/test_conv_traits.cpp)
 
 # Tests convolution problem description and parameter handling
 add_ck_builder_test(test_ckb_conv_description
     test_conv_description.cpp)
 
 ################################################################################
 # REGRESSION TESTS - Integration Tests (With Kernel Compilation)
 ################################################################################
@@ -181,6 +189,7 @@ set(CKB_SMOKE_TESTS
     test_ckb_inline_diff
     test_ckb_conv_traits
     test_ckb_conv_description
+    test_ckb_reference_execution
 )

 foreach(test_target ${CKB_SMOKE_TESTS})
@@ -479,4 +479,13 @@ using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate<TileTh
                                                                          TileConvSpecialization_,
                                                                          TileOptimizations_>;

+// Reference algorithm descriptor - for GPU reference validation
+// This is a simple algorithm that requires no complex configuration,
+// just a specialization marker to identify it as a reference implementation.
+struct ConvAlgorithm_Reference
+{
+    static constexpr auto specialization = ckb::ConvAlgorithmSpecialization::REFERENCE;
+    // GPU reference uses simple algorithm, no tile configuration needed
+};
+
 } // namespace ck_tile::builder::test
experimental/builder/test/validation/test_reference_execution.cpp (new file, 1031 lines)
File diff suppressed because it is too large
@@ -0,0 +1,117 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT

// Test: Verify InstanceTraits works for Reference kernels

#include "ck_tile/builder/conv_builder.hpp"
#include "ck_tile/builder/types.hpp"
#include "ck_tile/builder/reflect/instance_traits_reference.hpp"
#include "impl/conv_algorithm_types.hpp"
#include "impl/conv_signature_types.hpp"
#include <gtest/gtest.h>

namespace {

using namespace ck_tile::builder;
using namespace ck_tile::builder::test;

TEST(ReferenceInstanceTraits, Forward_2D_FP16)
{
    // Create a reference forward kernel
    constexpr ConvSignature sig{.spatial_dim             = 2,
                                .direction               = ConvDirection::FORWARD,
                                .data_type               = DataType::FP16,
                                .accumulation_data_type  = DataType::FP32,
                                .input  = {.config = {.layout = TensorLayout::NHWGC}},
                                .weight = {.config = {.layout = TensorLayout::GKYXC}},
                                .output = {.config = {.layout = TensorLayout::NHWGK}}};

    constexpr auto ref_alg = ConvAlgorithm_Reference{};
    using RefKernel        = ConvBuilder<sig, ref_alg>::Instance;

    // Use InstanceTraits to query properties
    using Traits = ck_tile::reflect::InstanceTraits<RefKernel>;

    // Verify spatial dimension
    EXPECT_EQ(Traits::kSpatialDim, 2);

    // Verify direction
    EXPECT_EQ(Traits::direction, ConvDirection::FORWARD);

    // Verify data types
    EXPECT_TRUE((std::is_same_v<Traits::ADataType, ck::half_t>));
    EXPECT_TRUE((std::is_same_v<Traits::BDataType, ck::half_t>));
    EXPECT_TRUE((std::is_same_v<Traits::EDataType, ck::half_t>));

    // Verify layouts
    EXPECT_TRUE((std::is_same_v<Traits::InLayout, ck::tensor_layout::convolution::NHWGC>));
    EXPECT_TRUE((std::is_same_v<Traits::WeiLayout, ck::tensor_layout::convolution::GKYXC>));
    EXPECT_TRUE((std::is_same_v<Traits::OutLayout, ck::tensor_layout::convolution::NHWGK>));

    // Verify elementwise operations (always PassThrough for reference)
    EXPECT_TRUE(
        (std::is_same_v<Traits::AElementwiseOperation, ck_tile::element_wise::PassThrough>));
    EXPECT_TRUE(
        (std::is_same_v<Traits::BElementwiseOperation, ck_tile::element_wise::PassThrough>));
    EXPECT_TRUE(
        (std::is_same_v<Traits::CDEElementwiseOperation, ck_tile::element_wise::PassThrough>));

    // Verify block size is 0 (N/A for reference)
    EXPECT_EQ(Traits::kBlockSize, 0);

    // Verify instance_string() - now includes data type and layouts!
    std::string instance_str = Traits::instance_string();
    EXPECT_EQ(instance_str, "GPU_Reference_Forward_2D_fp16_NHWGC_GKYXC_NHWGK");

    std::cout << "✓ Forward InstanceTraits validated: " << instance_str << std::endl;
}

TEST(ReferenceInstanceTraits, BackwardData_2D_FP16)
{
    constexpr ConvSignature sig{.spatial_dim             = 2,
                                .direction               = ConvDirection::BACKWARD_DATA,
                                .data_type               = DataType::FP16,
                                .accumulation_data_type  = DataType::FP32,
                                .input  = {.config = {.layout = TensorLayout::NHWGC}},
                                .weight = {.config = {.layout = TensorLayout::GKYXC}},
                                .output = {.config = {.layout = TensorLayout::NHWGK}}};

    constexpr auto ref_alg = ConvAlgorithm_Reference{};
    using RefKernel        = ConvBuilder<sig, ref_alg>::Instance;

    using Traits = ck_tile::reflect::InstanceTraits<RefKernel>;

    EXPECT_EQ(Traits::kSpatialDim, 2);
    EXPECT_EQ(Traits::direction, ConvDirection::BACKWARD_DATA);

    std::string instance_str = Traits::instance_string();
    EXPECT_EQ(instance_str, "GPU_Reference_BackwardData_2D_fp16_NHWGC_GKYXC_NHWGK");

    std::cout << "✓ Backward Data InstanceTraits validated: " << instance_str << std::endl;
}

TEST(ReferenceInstanceTraits, BackwardWeight_2D_FP16)
{
    constexpr ConvSignature sig{.spatial_dim             = 2,
                                .direction               = ConvDirection::BACKWARD_WEIGHT,
                                .data_type               = DataType::FP16,
                                .accumulation_data_type  = DataType::FP32,
                                .input  = {.config = {.layout = TensorLayout::NHWGC}},
                                .weight = {.config = {.layout = TensorLayout::GKYXC}},
                                .output = {.config = {.layout = TensorLayout::NHWGK}}};

    constexpr auto ref_alg = ConvAlgorithm_Reference{};
    using RefKernel        = ConvBuilder<sig, ref_alg>::Instance;

    using Traits = ck_tile::reflect::InstanceTraits<RefKernel>;

    EXPECT_EQ(Traits::kSpatialDim, 2);
    EXPECT_EQ(Traits::direction, ConvDirection::BACKWARD_WEIGHT);

    std::string instance_str = Traits::instance_string();
    EXPECT_EQ(instance_str, "GPU_Reference_BackwardWeight_2D_fp16_NHWGC_GKYXC_NHWGK");

    std::cout << "✓ Backward Weight InstanceTraits validated: " << instance_str << std::endl;
}

} // namespace