[CK_BUILDER] Convert convolution traits to a struct with factory functions (#3547)

* Factor helpers out of conv_traits.hpp

* Create a non-templated conv_traits struct

* Migrate to new instance-specific instance_to_conv_traits functions

* Clean up reflection concepts

* Clean up ConvTraits helpers

* Update testing for convolution traits

This is a lot of cleanup on tests to have verbose coverage of feature
extraction, explicit tests for each supported device kernel, and
simple, readable test code.

* Address reviewer comments and resolve merge conflict
This commit is contained in:
John Shumway
2026-01-15 01:03:21 -08:00
committed by GitHub
parent 8705fdcb0c
commit 5122637215
17 changed files with 2288 additions and 1875 deletions

View File

@@ -108,7 +108,8 @@ target_link_libraries(test_ckb_reference_execution PRIVATE utility)
# Tests convolution trait selection and configuration
add_ck_builder_test(test_ckb_conv_traits
conv/ck/test_conv_traits.cpp
conv/ck/unit_instance_to_conv_traits.cpp)
conv/ck/unit_instance_to_conv_traits_features.cpp
conv/ck/unit_instance_to_conv_traits_instances.cpp)
# Tests convolution problem description and parameter handling
add_ck_builder_test(test_ckb_conv_description

View File

@@ -6,7 +6,7 @@
#include <concepts>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck_tile/builder/reflect/conv_traits.hpp>
#include <ck_tile/builder/reflect/instance_to_conv_traits.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp>
@@ -86,72 +86,72 @@ TEST_F(ConvTraitsTest, ConvFwdTraitsExtraction)
ck::half_t, // BComputeDataType
false>; // DirectLoad
// Use ConvTraits to extract compile-time information
using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;
// Use instance_to_conv_traits to extract compile-time information
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
// Verify signature information
EXPECT_EQ(Traits::spatial_dim, 2);
EXPECT_EQ(Traits::direction, ConvDirection::FORWARD);
EXPECT_THAT(Traits::layout,
EXPECT_EQ(traits.spatial_dim, 2);
EXPECT_EQ(traits.direction, ConvDirection::FORWARD);
EXPECT_THAT(traits.layout,
ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK));
EXPECT_EQ(Traits::data_type, DataType::FP16);
EXPECT_EQ(Traits::input_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::weight_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::output_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(traits.data_type, DataType::FP16);
EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH);
// Verify specializations
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
EXPECT_EQ(traits.gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
// Verify algorithm information
EXPECT_EQ(Traits::thread_block_size, 256);
EXPECT_EQ(traits.thread_block_size, 256);
// Verify tile dimensions
EXPECT_EQ(Traits::tile_dims.m, 128);
EXPECT_EQ(Traits::tile_dims.n, 128);
EXPECT_EQ(Traits::tile_dims.k, 16);
EXPECT_EQ(traits.tile_dims.m, 128);
EXPECT_EQ(traits.tile_dims.n, 128);
EXPECT_EQ(traits.tile_dims.k, 16);
// Verify A tile transfer info
EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k0, 2);
EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.m_or_n, 128);
EXPECT_EQ(Traits::a_tile_transfer.tile_dimensions.k1, 8);
EXPECT_EQ(Traits::a_tile_transfer.transfer_params.k1, 8);
EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
EXPECT_THAT(Traits::a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
EXPECT_THAT(Traits::a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_vector_dim, 2);
EXPECT_EQ(Traits::a_tile_transfer.transfer_params.src_scalar_per_vector, 8);
EXPECT_EQ(Traits::a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
EXPECT_TRUE(Traits::a_tile_transfer.transfer_params.lds_padding);
EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k0, 2);
EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.m_or_n, 128);
EXPECT_EQ(traits.a_tile_transfer.tile_dimensions.k1, 8);
EXPECT_EQ(traits.a_tile_transfer.transfer_params.k1, 8);
EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
EXPECT_THAT(traits.a_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
EXPECT_THAT(traits.a_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_vector_dim, 2);
EXPECT_EQ(traits.a_tile_transfer.transfer_params.src_scalar_per_vector, 8);
EXPECT_EQ(traits.a_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
EXPECT_TRUE(traits.a_tile_transfer.transfer_params.lds_padding);
// Verify B tile transfer info
EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k0, 2);
EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.m_or_n, 128);
EXPECT_EQ(Traits::b_tile_transfer.tile_dimensions.k1, 8);
EXPECT_EQ(Traits::b_tile_transfer.transfer_params.k1, 8);
EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
EXPECT_THAT(Traits::b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
EXPECT_THAT(Traits::b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_vector_dim, 2);
EXPECT_EQ(Traits::b_tile_transfer.transfer_params.src_scalar_per_vector, 8);
EXPECT_EQ(Traits::b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
EXPECT_TRUE(Traits::b_tile_transfer.transfer_params.lds_padding);
EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k0, 2);
EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.m_or_n, 128);
EXPECT_EQ(traits.b_tile_transfer.tile_dimensions.k1, 8);
EXPECT_EQ(traits.b_tile_transfer.transfer_params.k1, 8);
EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_dims, ElementsAre(4, 64, 1));
EXPECT_THAT(traits.b_tile_transfer.transfer_params.thread_cluster_order, ElementsAre(1, 0, 2));
EXPECT_THAT(traits.b_tile_transfer.transfer_params.src_access_order, ElementsAre(1, 0, 2));
EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_vector_dim, 2);
EXPECT_EQ(traits.b_tile_transfer.transfer_params.src_scalar_per_vector, 8);
EXPECT_EQ(traits.b_tile_transfer.transfer_params.dst_scalar_per_vector_k1, 8);
EXPECT_TRUE(traits.b_tile_transfer.transfer_params.lds_padding);
// Verify warp GEMM params
EXPECT_EQ(Traits::warp_gemm.gemm_m, 32);
EXPECT_EQ(Traits::warp_gemm.gemm_n, 32);
EXPECT_EQ(Traits::warp_gemm.m_iter, 4);
EXPECT_EQ(Traits::warp_gemm.n_iter, 4);
EXPECT_EQ(traits.warp_gemm.gemm_m, 32);
EXPECT_EQ(traits.warp_gemm.gemm_n, 32);
EXPECT_EQ(traits.warp_gemm.m_iter, 4);
EXPECT_EQ(traits.warp_gemm.n_iter, 4);
// Verify output tile transfer info
EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1);
EXPECT_EQ(Traits::c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1);
EXPECT_THAT(Traits::c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8));
EXPECT_EQ(Traits::c_tile_transfer.scalar_per_vector, 8);
EXPECT_EQ(traits.c_tile_transfer.shuffle_params.m_gemms_per_shuffle, 1);
EXPECT_EQ(traits.c_tile_transfer.shuffle_params.n_gemms_per_shuffle, 1);
EXPECT_THAT(traits.c_tile_transfer.thread_cluster_dims, ElementsAre(1, 32, 1, 8));
EXPECT_EQ(traits.c_tile_transfer.scalar_per_vector, 8);
// Verify pipeline configuration
EXPECT_EQ(Traits::pipeline_scheduler, PipelineScheduler::INTRAWAVE);
EXPECT_EQ(Traits::pipeline_version, PipelineVersion::V1);
EXPECT_EQ(traits.pipeline_scheduler, PipelineScheduler::INTRAWAVE);
EXPECT_EQ(traits.pipeline_version, PipelineVersion::V1);
}
// Test ConvTraits with DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
@@ -214,30 +214,30 @@ TEST_F(ConvTraitsTest, ConvFwdBaseTraitsExtraction)
ck::LoopScheduler::Default, // LoopSched
1>; // NumGroupsToMerge
// Use ConvTraits to extract compile-time information
using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;
// Use instance_to_conv_traits to extract compile-time information
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
// Verify signature information
EXPECT_EQ(Traits::spatial_dim, 2);
EXPECT_EQ(Traits::direction, ConvDirection::FORWARD);
EXPECT_THAT(Traits::layout,
EXPECT_EQ(traits.spatial_dim, 2);
EXPECT_EQ(traits.direction, ConvDirection::FORWARD);
EXPECT_THAT(traits.layout,
ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK));
EXPECT_EQ(Traits::data_type, DataType::FP16);
EXPECT_EQ(Traits::input_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::weight_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::output_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(traits.data_type, DataType::FP16);
EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH);
// Verify specializations
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
EXPECT_EQ(traits.gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
// Verify algorithm information
EXPECT_EQ(Traits::thread_block_size, 256);
EXPECT_EQ(traits.thread_block_size, 256);
// Verify tile dimensions
EXPECT_EQ(Traits::tile_dims.m, 128);
EXPECT_EQ(Traits::tile_dims.n, 128);
EXPECT_EQ(Traits::tile_dims.k, 16);
EXPECT_EQ(traits.tile_dims.m, 128);
EXPECT_EQ(traits.tile_dims.n, 128);
EXPECT_EQ(traits.tile_dims.k, 16);
}
// Test ConvTraits with DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction)
@@ -298,29 +298,29 @@ TEST_F(ConvTraitsTest, ConvFwdLargeTensorTraitsExtraction)
ck::half_t, // BComputeDataType
ck::LoopScheduler::Default>; // LoopSched
// Use ConvTraits to extract compile-time information
using Traits = ck_tile::reflect::conv::ConvTraits<DeviceInstance>;
// Use instance_to_conv_traits to extract compile-time information
const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
// Verify signature information
EXPECT_EQ(Traits::spatial_dim, 2);
EXPECT_EQ(Traits::direction, ConvDirection::FORWARD);
EXPECT_THAT(Traits::layout,
EXPECT_EQ(traits.spatial_dim, 2);
EXPECT_EQ(traits.direction, ConvDirection::FORWARD);
EXPECT_THAT(traits.layout,
ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK));
EXPECT_EQ(Traits::data_type, DataType::FP16);
EXPECT_EQ(Traits::input_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::weight_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(Traits::output_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(traits.data_type, DataType::FP16);
EXPECT_EQ(traits.input_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(traits.weight_element_op, ElementwiseOperation::PASS_THROUGH);
EXPECT_EQ(traits.output_element_op, ElementwiseOperation::PASS_THROUGH);
// Verify specializations
EXPECT_EQ(Traits::gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
EXPECT_EQ(Traits::conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
EXPECT_EQ(traits.gemm_padding, ck_tile::builder::GemmPadding::DEFAULT);
EXPECT_EQ(traits.conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
// Verify algorithm information
EXPECT_EQ(Traits::thread_block_size, 256);
EXPECT_EQ(traits.thread_block_size, 256);
// Verify tile dimensions
EXPECT_EQ(Traits::tile_dims.m, 128);
EXPECT_EQ(Traits::tile_dims.n, 128);
EXPECT_EQ(Traits::tile_dims.k, 16);
EXPECT_EQ(traits.tile_dims.m, 128);
EXPECT_EQ(traits.tile_dims.n, 128);
EXPECT_EQ(traits.tile_dims.k, 16);
}
} // anonymous namespace

View File

@@ -0,0 +1,800 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// ============================================================================
// Unit Tests for Individual Conversion Functions
// ============================================================================
//
// PURPOSE:
// --------
// These tests verify individual conversion and extraction functions that
// transform raw CK kernel parameters into semantic types. Each test
// focuses on a single conversion function to ensure it correctly maps
// CK types to builder enums and structures.
//
// TEST COVERAGE:
// --------------
// 1. Enum Conversions:
// - Pipeline versions (BlockGemmPipelineVersion and PipelineVersion)
// - Pipeline schedulers (BlockGemmPipelineScheduler and LoopScheduler)
//
// 2. Elementwise Operations (14 operations):
// - PassThrough, Scale, Relu, Gelu, Sigmoid, Tanh, ScaleAdd
// - Silu, Swish, Elu, LeakyRelu, UnaryConvert, ConvScale, ConvScaleAdd
//
// 3. Convolution Properties:
// - Direction detection (Forward)
// - Specializations (Default, Filter1x1Pad0, Filter1x1Stride1Pad0,
// Filter3x3, OddC)
//
// 4. Layout Detection:
// - 1D layouts (GNWC, NWGC, NGCW)
// - 2D layouts (GNHWC, NHWGC, NGCHW with GKYXC/GKCYX)
// - 3D layouts (GNDHWC, NDHWGC, NGCDHW)
//
// 5. Data Type Detection:
// - FP16, BF16, FP32, I8
//
// 6. Pipeline Configuration:
// - Pipeline versions (V2, V3)
// - Schedulers (Interwave)
//
// 7. GEMM Padding Variations (16 types):
// - Default, MNK, M, N, K, MN, MK, NK
// - O, MO, NO, KO, MNO, MKO, NKO, MNKO
// ============================================================================
#include "ck/utility/scheduler_enum.hpp"
#include "ck_tile/builder/types.hpp"
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck_tile/builder/reflect/instance_to_conv_traits.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp>
namespace {
using ::ck_tile::builder::ConvDirection;
using ::ck_tile::builder::DataType;
using ::ck_tile::builder::ElementwiseOperation;
using ::ck_tile::builder::GemmPadding;
using ::ck_tile::builder::PipelineScheduler;
using ::ck_tile::builder::PipelineVersion;
using ::ck_tile::builder::TensorLayout;
using ::testing::ElementsAre;
// ============================================================================
// Test Helper Templates
// ============================================================================
// These templates reduce boilerplate by providing sensible defaults for
// template parameters that don't vary in most tests.
// ============================================================================
namespace defaults {

// Fixed kernel tuning values shared by every test instance in this file.
// They sit inside the file's anonymous namespace, so plain `constexpr` already
// gives them internal linkage.

// Block and tile sizes
constexpr int kBlockSize = 256;
constexpr int kMPerBlock = 128;
constexpr int kNPerBlock = 128;
constexpr int kKPerBlock = 16;
constexpr int kAK1       = 8;
constexpr int kBK1       = 8;

// XDL settings
constexpr int kMPerXDL     = 32;
constexpr int kNPerXDL     = 32;
constexpr int kMXdlPerWave = 4;
constexpr int kNXdlPerWave = 4;

// A block transfer
using DefaultABlockTransferThreadClusterLengths      = ck::Sequence<4, 64, 1>;
using DefaultABlockTransferThreadClusterArrangeOrder = ck::Sequence<1, 0, 2>;
using DefaultABlockTransferSrcAccessOrder            = ck::Sequence<1, 0, 2>;
constexpr int kABlockTransferSrcVectorDim            = 2;
constexpr int kABlockTransferSrcScalarPerVector      = 8;
constexpr int kABlockTransferDstScalarPerVector_AK1  = 8;
constexpr int kABlockLdsExtraM                       = 1;

// B block transfer
using DefaultBBlockTransferThreadClusterLengths      = ck::Sequence<4, 64, 1>;
using DefaultBBlockTransferThreadClusterArrangeOrder = ck::Sequence<1, 0, 2>;
using DefaultBBlockTransferSrcAccessOrder            = ck::Sequence<1, 0, 2>;
constexpr int kBBlockTransferSrcVectorDim            = 2;
constexpr int kBBlockTransferSrcScalarPerVector      = 8;
constexpr int kBBlockTransferDstScalarPerVector_BK1  = 8;
constexpr int kBBlockLdsExtraN                       = 1;

// C shuffle / CDE block transfer
using DefaultCDEBlockTransferClusterLengths              = ck::Sequence<1, 32, 1, 8>;
constexpr int kCShuffleMXdlPerWavePerShuffle             = 1;
constexpr int kCShuffleNXdlPerWavePerShuffle             = 1;
constexpr int kCDEBlockTransferScalarPerVector_NPerBlock = 8;

// Direct load is disabled for all test instances.
constexpr bool kDirectLoad = false;

} // namespace defaults
// DeviceInstanceForTests - V3 variant with sensible defaults
//
// Alias template over DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 that
// exposes only the parameters varied by the tests in this file: spatial
// dimension, the three tensor layouts, the A/B/E and accumulator data types,
// the three elementwise operations, the convolution and GEMM specializations,
// and the block-GEMM pipeline scheduler and version.
//
// Every parameter has a default, so DeviceInstanceForTests_V3<> names a 2D
// GNHWC/GKYXC/GNHWK ck::half_t instance with PassThrough operations. All other
// kernel arguments come from the `defaults` namespace.
//
// The parameters are positional: to override a later one, a caller must spell
// out every earlier one.
template <int NDimSpatial = 2,
typename ALayout = ck::tensor_layout::convolution::GNHWC,
typename BLayout = ck::tensor_layout::convolution::GKYXC,
typename ELayout = ck::tensor_layout::convolution::GNHWK,
typename ADataType = ck::half_t,
typename BDataType = ck::half_t,
typename EDataType = ck::half_t,
typename AccDataType = float,
typename AElementwiseOperation = ck::tensor_operation::element_wise::PassThrough,
typename BElementwiseOperation = ck::tensor_operation::element_wise::PassThrough,
typename CDEElementwiseOperation = ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::device::ConvolutionForwardSpecialization ConvForwardSpecialization =
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
ck::tensor_operation::device::GemmSpecialization GemmSpec =
ck::tensor_operation::device::GemmSpecialization::Default,
ck::BlockGemmPipelineScheduler BlkGemmPipeSched =
ck::BlockGemmPipelineScheduler::Intrawave,
ck::BlockGemmPipelineVersion BlkGemmPipelineVer = ck::BlockGemmPipelineVersion::v1>
using DeviceInstanceForTests_V3 =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
NDimSpatial,
ALayout,
BLayout,
ck::Tuple<>, // empty tuple; presumably the D-tensor layout list - confirm against the kernel
ELayout,
ADataType,
BDataType,
AccDataType,
ADataType, // ADataType reused here; presumably the CShuffle data type - confirm
ck::Tuple<>, // empty tuple; presumably the D-tensor data type list - confirm
EDataType,
AElementwiseOperation,
BElementwiseOperation,
CDEElementwiseOperation,
ConvForwardSpecialization,
GemmSpec,
// From here to the pipeline arguments everything is fixed by `defaults`.
defaults::kBlockSize,
defaults::kMPerBlock,
defaults::kNPerBlock,
defaults::kKPerBlock,
defaults::kAK1,
defaults::kBK1,
defaults::kMPerXDL,
defaults::kNPerXDL,
defaults::kMXdlPerWave,
defaults::kNXdlPerWave,
defaults::DefaultABlockTransferThreadClusterLengths,
defaults::DefaultABlockTransferThreadClusterArrangeOrder,
defaults::DefaultABlockTransferSrcAccessOrder,
defaults::kABlockTransferSrcVectorDim,
defaults::kABlockTransferSrcScalarPerVector,
defaults::kABlockTransferDstScalarPerVector_AK1,
defaults::kABlockLdsExtraM,
defaults::DefaultBBlockTransferThreadClusterLengths,
defaults::DefaultBBlockTransferThreadClusterArrangeOrder,
defaults::DefaultBBlockTransferSrcAccessOrder,
defaults::kBBlockTransferSrcVectorDim,
defaults::kBBlockTransferSrcScalarPerVector,
defaults::kBBlockTransferDstScalarPerVector_BK1,
defaults::kBBlockLdsExtraN,
defaults::kCShuffleMXdlPerWavePerShuffle,
defaults::kCShuffleNXdlPerWavePerShuffle,
defaults::DefaultCDEBlockTransferClusterLengths,
defaults::kCDEBlockTransferScalarPerVector_NPerBlock,
BlkGemmPipeSched,
BlkGemmPipelineVer,
ADataType, // presumably the A compute data type - confirm
BDataType, // presumably the B compute data type - confirm
defaults::kDirectLoad>;
// Test case helper for specialization testing
//
// Only the ConvolutionForwardSpecialization (Spec) varies. The eleven arguments
// before it repeat the DeviceInstanceForTests_V3 defaults, because the
// parameters are positional and must be spelled out to reach Spec. The
// parameters after Spec keep their defaults.
template <ck::tensor_operation::device::ConvolutionForwardSpecialization Spec>
using SpecializationTestInstance =
DeviceInstanceForTests_V3<2,
ck::tensor_layout::convolution::GNHWC,
ck::tensor_layout::convolution::GKYXC,
ck::tensor_layout::convolution::GNHWK,
ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
Spec>;
// Test case helper for layout testing (1D, 2D, 3D)
//
// Forwards the spatial dimension and the A/B/E tensor layouts to
// DeviceInstanceForTests_V3; every other parameter keeps its default.
template <int NDim, typename ALayout, typename BLayout, typename ELayout>
using LayoutTestInstance = DeviceInstanceForTests_V3<NDim, ALayout, BLayout, ELayout>;
// Test case helper for data type testing.
//
// The A, B and E tensors all use ElemType, on the default 2D
// GNHWC/GKYXC/GNHWK layouts; AccType is the accumulator type (float unless
// overridden). The element parameter is not called DataType, so it does not
// hide the ::ck_tile::builder::DataType enum imported into this namespace.
template <typename ElemType, typename AccType = float>
using DataTypeTestInstance = DeviceInstanceForTests_V3<2,
                                                       ck::tensor_layout::convolution::GNHWC,
                                                       ck::tensor_layout::convolution::GKYXC,
                                                       ck::tensor_layout::convolution::GNHWK,
                                                       ElemType,
                                                       ElemType,
                                                       ElemType,
                                                       AccType>;
// Test case helper for pipeline version testing
//
// Only the BlockGemmPipelineVersion (PipelineVer) varies. It is the last
// parameter of DeviceInstanceForTests_V3, so every earlier argument is repeated
// here with its default value, including the Intrawave scheduler.
template <ck::BlockGemmPipelineVersion PipelineVer>
using PipelineVersionTestInstance = DeviceInstanceForTests_V3<
2,
ck::tensor_layout::convolution::GNHWC,
ck::tensor_layout::convolution::GKYXC,
ck::tensor_layout::convolution::GNHWK,
ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
ck::tensor_operation::device::GemmSpecialization::Default,
ck::BlockGemmPipelineScheduler::Intrawave,
PipelineVer>;
// Test case helper for pipeline scheduler testing
//
// Only the BlockGemmPipelineScheduler (Scheduler) varies. Every earlier
// positional argument repeats its default value; the pipeline version, which
// follows Scheduler, keeps its default (v1).
template <ck::BlockGemmPipelineScheduler Scheduler>
using PipelineSchedulerTestInstance = DeviceInstanceForTests_V3<
2,
ck::tensor_layout::convolution::GNHWC,
ck::tensor_layout::convolution::GKYXC,
ck::tensor_layout::convolution::GNHWK,
ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
ck::tensor_operation::device::GemmSpecialization::Default,
Scheduler>;
// Test case helper for GEMM padding testing
//
// Only the GemmSpecialization (GemmSpec) varies. Every earlier positional
// argument repeats its default value; the pipeline scheduler and version, which
// follow GemmSpec, keep their defaults (Intrawave, v1).
template <ck::tensor_operation::device::GemmSpecialization GemmSpec>
using GemmPaddingTestInstance = DeviceInstanceForTests_V3<
2,
ck::tensor_layout::convolution::GNHWC,
ck::tensor_layout::convolution::GKYXC,
ck::tensor_layout::convolution::GNHWK,
ck::half_t,
ck::half_t,
ck::half_t,
float,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::element_wise::PassThrough,
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
GemmSpec>;
// ============================================================================
// Test Enum Conversion Functions
// ============================================================================
TEST(InstanceToConvTraits, ConvertsBlockGemmPipelineVersion)
{
    // Each ck::BlockGemmPipelineVersion enumerator (v1..v5) must convert to the
    // builder PipelineVersion enumerator with the same number.
    namespace reflect_conv = ck_tile::reflect::conv;
    using CkVersion        = ::ck::BlockGemmPipelineVersion;
    using BuilderVersion   = ::ck_tile::builder::PipelineVersion;

    EXPECT_EQ(reflect_conv::convert_pipeline_version<CkVersion::v1>(), BuilderVersion::V1);
    EXPECT_EQ(reflect_conv::convert_pipeline_version<CkVersion::v2>(), BuilderVersion::V2);
    EXPECT_EQ(reflect_conv::convert_pipeline_version<CkVersion::v3>(), BuilderVersion::V3);
    EXPECT_EQ(reflect_conv::convert_pipeline_version<CkVersion::v4>(), BuilderVersion::V4);
    EXPECT_EQ(reflect_conv::convert_pipeline_version<CkVersion::v5>(), BuilderVersion::V5);
}
TEST(InstanceToConvTraits, ConvertsPipelineVersion)
{
    // Same conversion, but starting from ck::PipelineVersion. This test covers
    // its v1, v2, v4 and weight_only enumerators.
    namespace reflect_conv = ck_tile::reflect::conv;
    using CkVersion        = ck::PipelineVersion;

    EXPECT_EQ(reflect_conv::convert_pipeline_version<CkVersion::v1>(), PipelineVersion::V1);
    EXPECT_EQ(reflect_conv::convert_pipeline_version<CkVersion::v2>(), PipelineVersion::V2);
    EXPECT_EQ(reflect_conv::convert_pipeline_version<CkVersion::v4>(), PipelineVersion::V4);
    EXPECT_EQ(reflect_conv::convert_pipeline_version<CkVersion::weight_only>(),
              PipelineVersion::WEIGHT_ONLY);
}
TEST(InstanceToConvTraits, ConvertsBlockGemmPipelineScheduler)
{
    // Both ck::BlockGemmPipelineScheduler enumerators must convert to the
    // builder PipelineScheduler enumerator of the same name.
    namespace reflect_conv = ck_tile::reflect::conv;
    using CkScheduler      = ck::BlockGemmPipelineScheduler;

    EXPECT_EQ(reflect_conv::convert_pipeline_scheduler<CkScheduler::Intrawave>(),
              PipelineScheduler::INTRAWAVE);
    EXPECT_EQ(reflect_conv::convert_pipeline_scheduler<CkScheduler::Interwave>(),
              PipelineScheduler::INTERWAVE);
}
TEST(InstanceToConvTraits, ConvertsLoopScheduler)
{
    // ck::LoopScheduler goes through the same conversion function:
    // Default -> DEFAULT and Interwave -> INTERWAVE.
    namespace reflect_conv = ck_tile::reflect::conv;
    using CkScheduler      = ck::LoopScheduler;

    EXPECT_EQ(reflect_conv::convert_pipeline_scheduler<CkScheduler::Default>(),
              PipelineScheduler::DEFAULT);
    EXPECT_EQ(reflect_conv::convert_pipeline_scheduler<CkScheduler::Interwave>(),
              PipelineScheduler::INTERWAVE);
}
// ============================================================================
// Test Elementwise Operations
// ============================================================================
TEST(InstanceToConvTraits, ExtractsPassThroughOperation)
{
    // element_wise::PassThrough must be reported as PASS_THROUGH.
    using CkOp               = ck::tensor_operation::element_wise::PassThrough;
    constexpr auto extracted = ck_tile::reflect::conv::elementwise_op<CkOp>();
    EXPECT_EQ(extracted, ElementwiseOperation::PASS_THROUGH);
}
TEST(InstanceToConvTraits, ExtractsScaleOperation)
{
    // element_wise::Scale must be reported as SCALE.
    using CkOp               = ck::tensor_operation::element_wise::Scale;
    constexpr auto extracted = ck_tile::reflect::conv::elementwise_op<CkOp>();
    EXPECT_EQ(extracted, ElementwiseOperation::SCALE);
}
TEST(InstanceToConvTraits, ExtractsReluOperation)
{
    // element_wise::Relu must be reported as RELU.
    using CkOp               = ck::tensor_operation::element_wise::Relu;
    constexpr auto extracted = ck_tile::reflect::conv::elementwise_op<CkOp>();
    EXPECT_EQ(extracted, ElementwiseOperation::RELU);
}
TEST(InstanceToConvTraits, ExtractsGeluOperation)
{
    // element_wise::Gelu must be reported as GELU.
    using CkOp               = ck::tensor_operation::element_wise::Gelu;
    constexpr auto extracted = ck_tile::reflect::conv::elementwise_op<CkOp>();
    EXPECT_EQ(extracted, ElementwiseOperation::GELU);
}
TEST(InstanceToConvTraits, ExtractsSigmoidOperation)
{
    // element_wise::Sigmoid must be reported as SIGMOID.
    using CkOp               = ck::tensor_operation::element_wise::Sigmoid;
    constexpr auto extracted = ck_tile::reflect::conv::elementwise_op<CkOp>();
    EXPECT_EQ(extracted, ElementwiseOperation::SIGMOID);
}
TEST(InstanceToConvTraits, ExtractsTanhOperation)
{
    // element_wise::TanH must be reported as TANH.
    using CkOp               = ck::tensor_operation::element_wise::TanH;
    constexpr auto extracted = ck_tile::reflect::conv::elementwise_op<CkOp>();
    EXPECT_EQ(extracted, ElementwiseOperation::TANH);
}
TEST(InstanceToConvTraits, ExtractsScaleAddOperation)
{
    // element_wise::ScaleAdd must be reported as SCALE_ADD.
    using CkOp               = ck::tensor_operation::element_wise::ScaleAdd;
    constexpr auto extracted = ck_tile::reflect::conv::elementwise_op<CkOp>();
    EXPECT_EQ(extracted, ElementwiseOperation::SCALE_ADD);
}
TEST(InstanceToConvTraits, ExtractsSiluOperation)
{
    // element_wise::Silu must be reported as SILU.
    using CkOp               = ck::tensor_operation::element_wise::Silu;
    constexpr auto extracted = ck_tile::reflect::conv::elementwise_op<CkOp>();
    EXPECT_EQ(extracted, ElementwiseOperation::SILU);
}
TEST(InstanceToConvTraits, ExtractsSwishOperation)
{
    // element_wise::Swish must be reported as SWISH.
    using CkOp               = ck::tensor_operation::element_wise::Swish;
    constexpr auto extracted = ck_tile::reflect::conv::elementwise_op<CkOp>();
    EXPECT_EQ(extracted, ElementwiseOperation::SWISH);
}
TEST(InstanceToConvTraits, ExtractsEluOperation)
{
    // element_wise::Elu must be reported as ELU.
    using CkOp               = ck::tensor_operation::element_wise::Elu;
    constexpr auto extracted = ck_tile::reflect::conv::elementwise_op<CkOp>();
    EXPECT_EQ(extracted, ElementwiseOperation::ELU);
}
TEST(InstanceToConvTraits, ExtractsLeakyReluOperation)
{
    // element_wise::LeakyRelu must be reported as LEAKY_RELU.
    using CkOp               = ck::tensor_operation::element_wise::LeakyRelu;
    constexpr auto extracted = ck_tile::reflect::conv::elementwise_op<CkOp>();
    EXPECT_EQ(extracted, ElementwiseOperation::LEAKY_RELU);
}
TEST(InstanceToConvTraits, ExtractsUnaryConvertOperation)
{
    // element_wise::UnaryConvert must be reported as UNARY_CONVERT.
    using CkOp               = ck::tensor_operation::element_wise::UnaryConvert;
    constexpr auto extracted = ck_tile::reflect::conv::elementwise_op<CkOp>();
    EXPECT_EQ(extracted, ElementwiseOperation::UNARY_CONVERT);
}
TEST(InstanceToConvTraits, ExtractsConvScaleOperation)
{
    // element_wise::ConvScale must be reported as CONV_SCALE.
    using CkOp               = ck::tensor_operation::element_wise::ConvScale;
    constexpr auto extracted = ck_tile::reflect::conv::elementwise_op<CkOp>();
    EXPECT_EQ(extracted, ElementwiseOperation::CONV_SCALE);
}
TEST(InstanceToConvTraits, ExtractsConvScaleAddOperation)
{
    // element_wise::ConvScaleAdd must be reported as CONV_SCALE_ADD.
    using CkOp               = ck::tensor_operation::element_wise::ConvScaleAdd;
    constexpr auto extracted = ck_tile::reflect::conv::elementwise_op<CkOp>();
    EXPECT_EQ(extracted, ElementwiseOperation::CONV_SCALE_ADD);
}
// ============================================================================
// Test Convolution Direction Detection
// ============================================================================
TEST(InstanceToConvTraits, DetectsForwardDirection)
{
    // The all-defaults V3 forward kernel instance must be reported as FORWARD.
    namespace reflect_conv = ck_tile::reflect::conv;
    const auto extracted   = reflect_conv::instance_to_conv_traits<DeviceInstanceForTests_V3<>>();
    EXPECT_EQ(extracted.direction, ConvDirection::FORWARD);
}
// ============================================================================
// Test Convolution Specialization Detection
// ============================================================================
TEST(InstanceToConvTraits, ExtractsDefaultSpecialization)
{
    // ConvolutionForwardSpecialization::Default -> ConvSpecialization::DEFAULT.
    using CkSpec   = ck::tensor_operation::device::ConvolutionForwardSpecialization;
    using Instance = SpecializationTestInstance<CkSpec::Default>;

    const auto extracted = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();
    EXPECT_EQ(extracted.conv_specialization, ck_tile::builder::ConvSpecialization::DEFAULT);
}
TEST(InstanceToConvTraits, ExtractsFilter1x1Pad0Specialization)
{
    // ConvolutionForwardSpecialization::Filter1x1Pad0 -> FILTER_1X1_PAD0.
    using CkSpec   = ck::tensor_operation::device::ConvolutionForwardSpecialization;
    using Instance = SpecializationTestInstance<CkSpec::Filter1x1Pad0>;

    const auto extracted = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();
    EXPECT_EQ(extracted.conv_specialization,
              ck_tile::builder::ConvSpecialization::FILTER_1X1_PAD0);
}
TEST(InstanceToConvTraits, ExtractsFilter1x1Stride1Pad0Specialization)
{
    // ConvolutionForwardSpecialization::Filter1x1Stride1Pad0 -> FILTER_1X1_STRIDE1_PAD0.
    using CkSpec   = ck::tensor_operation::device::ConvolutionForwardSpecialization;
    using Instance = SpecializationTestInstance<CkSpec::Filter1x1Stride1Pad0>;

    const auto extracted = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();
    EXPECT_EQ(extracted.conv_specialization,
              ck_tile::builder::ConvSpecialization::FILTER_1X1_STRIDE1_PAD0);
}
TEST(InstanceToConvTraits, ExtractsFilter3x3Specialization)
{
    // ConvolutionForwardSpecialization::Filter3x3 -> ConvSpecialization::FILTER_3x3.
    using CkSpec   = ck::tensor_operation::device::ConvolutionForwardSpecialization;
    using Instance = SpecializationTestInstance<CkSpec::Filter3x3>;

    const auto extracted = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();
    EXPECT_EQ(extracted.conv_specialization, ck_tile::builder::ConvSpecialization::FILTER_3x3);
}
TEST(InstanceToConvTraits, ExtractsOddCSpecialization)
{
    // ConvolutionForwardSpecialization::OddC -> ConvSpecialization::ODD_C.
    using CkSpec   = ck::tensor_operation::device::ConvolutionForwardSpecialization;
    using Instance = SpecializationTestInstance<CkSpec::OddC>;

    const auto extracted = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();
    EXPECT_EQ(extracted.conv_specialization, ck_tile::builder::ConvSpecialization::ODD_C);
}
// ============================================================================
// Test 1D Convolution Layout Detection
// ============================================================================
TEST(InstanceToConvTraits, ExtractsGnwcLayout)
{
    // 1D instance with GNWC input, GKXC weight and GNWK output layouts.
    namespace ck_layout = ck::tensor_layout::convolution;
    using Instance = LayoutTestInstance<1, ck_layout::GNWC, ck_layout::GKXC, ck_layout::GNWK>;

    const auto extracted = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();

    EXPECT_EQ(extracted.spatial_dim, 1);
    EXPECT_THAT(extracted.layout,
                ElementsAre(TensorLayout::GNWC, TensorLayout::GKXC, TensorLayout::GNWK));
}
TEST(InstanceToConvTraits, ExtractsNwgcLayout)
{
    // 1D instance with NWGC input, GKXC weight and NWGK output layouts.
    namespace ck_layout = ck::tensor_layout::convolution;
    using Instance = LayoutTestInstance<1, ck_layout::NWGC, ck_layout::GKXC, ck_layout::NWGK>;

    const auto extracted = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();

    EXPECT_EQ(extracted.spatial_dim, 1);
    EXPECT_THAT(extracted.layout,
                ElementsAre(TensorLayout::NWGC, TensorLayout::GKXC, TensorLayout::NWGK));
}
TEST(InstanceToConvTraits, ExtractsNgcwLayout)
{
    // 1D instance with NGCW input, GKXC weight and NGKW output layouts.
    namespace ck_layout = ck::tensor_layout::convolution;
    using Instance = LayoutTestInstance<1, ck_layout::NGCW, ck_layout::GKXC, ck_layout::NGKW>;

    const auto extracted = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();

    EXPECT_EQ(extracted.spatial_dim, 1);
    EXPECT_THAT(extracted.layout,
                ElementsAre(TensorLayout::NGCW, TensorLayout::GKXC, TensorLayout::NGKW));
}
// ============================================================================
// Test 2D Convolution Layout Detection
// ============================================================================
TEST(InstanceToConvTraits, ExtractsGnhwcLayout)
{
    // 2D instance with GNHWC input, GKYXC weight and GNHWK output layouts.
    using DeviceInstance = LayoutTestInstance<2,
                                              ck::tensor_layout::convolution::GNHWC,
                                              ck::tensor_layout::convolution::GKYXC,
                                              ck::tensor_layout::convolution::GNHWK>;
    const auto traits    = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();

    // Check the spatial dimension as well, as the 1D and 3D layout tests do.
    EXPECT_EQ(traits.spatial_dim, 2);
    EXPECT_THAT(traits.layout,
                ElementsAre(TensorLayout::GNHWC, TensorLayout::GKYXC, TensorLayout::GNHWK));
}
// LayoutTestInstance<2, NHWGC, GKYXC, NHWGK> must yield spatial_dim == 2 and
// the layout triple {NHWGC, GKYXC, NHWGK}.
TEST(InstanceToConvTraits, ExtractsNhwgcLayout)
{
    using DeviceInstance = LayoutTestInstance<2,
                                              ck::tensor_layout::convolution::NHWGC,
                                              ck::tensor_layout::convolution::GKYXC,
                                              ck::tensor_layout::convolution::NHWGK>;
    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
    // The spatial dimension is asserted as well, matching the 1D and 3D layout
    // tests, so a wrong dimension cannot hide behind a correct layout triple.
    EXPECT_EQ(traits.spatial_dim, 2);
    EXPECT_THAT(traits.layout,
                ElementsAre(TensorLayout::NHWGC, TensorLayout::GKYXC, TensorLayout::NHWGK));
}
// LayoutTestInstance<2, NGCHW, GKYXC, NGKHW> must yield spatial_dim == 2 and
// the layout triple {NGCHW, GKYXC, NGKHW}.
TEST(InstanceToConvTraits, ExtractsNgchwGkyxcLayout)
{
    using DeviceInstance = LayoutTestInstance<2,
                                              ck::tensor_layout::convolution::NGCHW,
                                              ck::tensor_layout::convolution::GKYXC,
                                              ck::tensor_layout::convolution::NGKHW>;
    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
    // The spatial dimension is asserted as well, matching the 1D and 3D layout
    // tests, so a wrong dimension cannot hide behind a correct layout triple.
    EXPECT_EQ(traits.spatial_dim, 2);
    EXPECT_THAT(traits.layout,
                ElementsAre(TensorLayout::NGCHW, TensorLayout::GKYXC, TensorLayout::NGKHW));
}
// LayoutTestInstance<2, NGCHW, GKCYX, NGKHW> must yield spatial_dim == 2 and
// the layout triple {NGCHW, GKCYX, NGKHW} (GKCYX weights rather than GKYXC).
TEST(InstanceToConvTraits, ExtractsNgchwGkcyxLayout)
{
    using DeviceInstance = LayoutTestInstance<2,
                                              ck::tensor_layout::convolution::NGCHW,
                                              ck::tensor_layout::convolution::GKCYX,
                                              ck::tensor_layout::convolution::NGKHW>;
    const auto traits = ck_tile::reflect::conv::instance_to_conv_traits<DeviceInstance>();
    // The spatial dimension is asserted as well, matching the 1D and 3D layout
    // tests, so a wrong dimension cannot hide behind a correct layout triple.
    EXPECT_EQ(traits.spatial_dim, 2);
    EXPECT_THAT(traits.layout,
                ElementsAre(TensorLayout::NGCHW, TensorLayout::GKCYX, TensorLayout::NGKHW));
}
// ============================================================================
// Test 3D Convolution Layout Detection
// ============================================================================
// LayoutTestInstance<3, GNDHWC, GKZYXC, GNDHWK> must yield spatial_dim == 3 and
// the layout triple {GNDHWC, GKZYXC, GNDHWK}.
TEST(InstanceToConvTraits, ExtractsGndhwcLayout)
{
    namespace conv_layout = ck::tensor_layout::convolution;
    using Instance =
        LayoutTestInstance<3, conv_layout::GNDHWC, conv_layout::GKZYXC, conv_layout::GNDHWK>;

    const auto conv_traits = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();

    EXPECT_EQ(conv_traits.spatial_dim, 3);
    EXPECT_THAT(conv_traits.layout,
                ElementsAre(TensorLayout::GNDHWC, TensorLayout::GKZYXC, TensorLayout::GNDHWK));
}
// LayoutTestInstance<3, NDHWGC, GKZYXC, NDHWGK> must yield spatial_dim == 3 and
// the layout triple {NDHWGC, GKZYXC, NDHWGK}.
TEST(InstanceToConvTraits, ExtractsNdhwgcLayout)
{
    namespace conv_layout = ck::tensor_layout::convolution;
    using Instance =
        LayoutTestInstance<3, conv_layout::NDHWGC, conv_layout::GKZYXC, conv_layout::NDHWGK>;

    const auto conv_traits = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();

    EXPECT_EQ(conv_traits.spatial_dim, 3);
    EXPECT_THAT(conv_traits.layout,
                ElementsAre(TensorLayout::NDHWGC, TensorLayout::GKZYXC, TensorLayout::NDHWGK));
}
// LayoutTestInstance<3, NGCDHW, GKZYXC, NGKDHW> must yield spatial_dim == 3 and
// the layout triple {NGCDHW, GKZYXC, NGKDHW}.
TEST(InstanceToConvTraits, ExtractsNgcdhwLayout)
{
    namespace conv_layout = ck::tensor_layout::convolution;
    using Instance =
        LayoutTestInstance<3, conv_layout::NGCDHW, conv_layout::GKZYXC, conv_layout::NGKDHW>;

    const auto conv_traits = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();

    EXPECT_EQ(conv_traits.spatial_dim, 3);
    EXPECT_THAT(conv_traits.layout,
                ElementsAre(TensorLayout::NGCDHW, TensorLayout::GKZYXC, TensorLayout::NGKDHW));
}
// ============================================================================
// Test Data Type Detection
// ============================================================================
// DataTypeTestInstance<ck::half_t> must be reported as DataType::FP16.
TEST(InstanceToConvTraits, ExtractsFp16DataType)
{
    using Fp16Instance = DataTypeTestInstance<ck::half_t>;

    const auto conv_traits = ck_tile::reflect::conv::instance_to_conv_traits<Fp16Instance>();

    EXPECT_EQ(conv_traits.data_type, DataType::FP16);
}
// DataTypeTestInstance<ck::bhalf_t> must be reported as DataType::BF16.
TEST(InstanceToConvTraits, ExtractsBf16DataType)
{
    using Bf16Instance = DataTypeTestInstance<ck::bhalf_t>;

    const auto conv_traits = ck_tile::reflect::conv::instance_to_conv_traits<Bf16Instance>();

    EXPECT_EQ(conv_traits.data_type, DataType::BF16);
}
// DataTypeTestInstance<float, float> must be reported as DataType::FP32.
TEST(InstanceToConvTraits, ExtractsFp32DataType)
{
    using Fp32Instance = DataTypeTestInstance<float, float>;

    const auto conv_traits = ck_tile::reflect::conv::instance_to_conv_traits<Fp32Instance>();

    EXPECT_EQ(conv_traits.data_type, DataType::FP32);
}
// DataTypeTestInstance<int8_t, int32_t> must be reported as DataType::I8.
TEST(InstanceToConvTraits, ExtractsI8DataType)
{
    using Int8Instance = DataTypeTestInstance<int8_t, int32_t>;

    const auto conv_traits = ck_tile::reflect::conv::instance_to_conv_traits<Int8Instance>();

    EXPECT_EQ(conv_traits.data_type, DataType::I8);
}
// ============================================================================
// Test Pipeline Version Detection
// ============================================================================
// ck::BlockGemmPipelineVersion::v2 must map to PipelineVersion::V2.
TEST(InstanceToConvTraits, ExtractsPipelineV2)
{
    using V2Instance = PipelineVersionTestInstance<ck::BlockGemmPipelineVersion::v2>;

    const auto conv_traits = ck_tile::reflect::conv::instance_to_conv_traits<V2Instance>();

    EXPECT_EQ(conv_traits.pipeline_version, PipelineVersion::V2);
}
// ck::BlockGemmPipelineVersion::v3 must map to PipelineVersion::V3.
TEST(InstanceToConvTraits, ExtractsPipelineV3)
{
    using V3Instance = PipelineVersionTestInstance<ck::BlockGemmPipelineVersion::v3>;

    const auto conv_traits = ck_tile::reflect::conv::instance_to_conv_traits<V3Instance>();

    EXPECT_EQ(conv_traits.pipeline_version, PipelineVersion::V3);
}
// ck::BlockGemmPipelineScheduler::Interwave must map to PipelineScheduler::INTERWAVE.
TEST(InstanceToConvTraits, ExtractsInterwaveScheduler)
{
    using InterwaveInstance =
        PipelineSchedulerTestInstance<ck::BlockGemmPipelineScheduler::Interwave>;

    const auto conv_traits =
        ck_tile::reflect::conv::instance_to_conv_traits<InterwaveInstance>();

    EXPECT_EQ(conv_traits.pipeline_scheduler, PipelineScheduler::INTERWAVE);
}
// ============================================================================
// Test GEMM Padding Detection
// ============================================================================
// GemmSpecialization::Default must map to GemmPadding::DEFAULT.
TEST(InstanceToConvTraits, ExtractsDefaultGemmPadding)
{
    namespace ck_device = ck::tensor_operation::device;
    using Instance      = GemmPaddingTestInstance<ck_device::GemmSpecialization::Default>;

    const auto conv_traits = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();

    EXPECT_EQ(conv_traits.gemm_padding, GemmPadding::DEFAULT);
}
// GemmSpecialization::MNKPadding must map to GemmPadding::MNK_PADDING.
TEST(InstanceToConvTraits, ExtractsMnkGemmPadding)
{
    namespace ck_device = ck::tensor_operation::device;
    using Instance      = GemmPaddingTestInstance<ck_device::GemmSpecialization::MNKPadding>;

    const auto conv_traits = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();

    EXPECT_EQ(conv_traits.gemm_padding, GemmPadding::MNK_PADDING);
}
// GemmSpecialization::MPadding must map to GemmPadding::M_PADDING.
TEST(InstanceToConvTraits, ExtractsMPaddingGemmPadding)
{
    namespace ck_device = ck::tensor_operation::device;
    using Instance      = GemmPaddingTestInstance<ck_device::GemmSpecialization::MPadding>;

    const auto conv_traits = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();

    EXPECT_EQ(conv_traits.gemm_padding, GemmPadding::M_PADDING);
}
// GemmSpecialization::NPadding must map to GemmPadding::N_PADDING.
TEST(InstanceToConvTraits, ExtractsNPaddingGemmPadding)
{
    namespace ck_device = ck::tensor_operation::device;
    using Instance      = GemmPaddingTestInstance<ck_device::GemmSpecialization::NPadding>;

    const auto conv_traits = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();

    EXPECT_EQ(conv_traits.gemm_padding, GemmPadding::N_PADDING);
}
// GemmSpecialization::KPadding must map to GemmPadding::K_PADDING.
TEST(InstanceToConvTraits, ExtractsKPaddingGemmPadding)
{
    namespace ck_device = ck::tensor_operation::device;
    using Instance      = GemmPaddingTestInstance<ck_device::GemmSpecialization::KPadding>;

    const auto conv_traits = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();

    EXPECT_EQ(conv_traits.gemm_padding, GemmPadding::K_PADDING);
}
// GemmSpecialization::MNPadding must map to GemmPadding::MN_PADDING.
TEST(InstanceToConvTraits, ExtractsMnPaddingGemmPadding)
{
    namespace ck_device = ck::tensor_operation::device;
    using Instance      = GemmPaddingTestInstance<ck_device::GemmSpecialization::MNPadding>;

    const auto conv_traits = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();

    EXPECT_EQ(conv_traits.gemm_padding, GemmPadding::MN_PADDING);
}
// GemmSpecialization::MKPadding must map to GemmPadding::MK_PADDING.
TEST(InstanceToConvTraits, ExtractsMkPaddingGemmPadding)
{
    namespace ck_device = ck::tensor_operation::device;
    using Instance      = GemmPaddingTestInstance<ck_device::GemmSpecialization::MKPadding>;

    const auto conv_traits = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();

    EXPECT_EQ(conv_traits.gemm_padding, GemmPadding::MK_PADDING);
}
// GemmSpecialization::NKPadding must map to GemmPadding::NK_PADDING.
TEST(InstanceToConvTraits, ExtractsNkPaddingGemmPadding)
{
    namespace ck_device = ck::tensor_operation::device;
    using Instance      = GemmPaddingTestInstance<ck_device::GemmSpecialization::NKPadding>;

    const auto conv_traits = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();

    EXPECT_EQ(conv_traits.gemm_padding, GemmPadding::NK_PADDING);
}
// GemmSpecialization::OPadding must map to GemmPadding::O_PADDING.
TEST(InstanceToConvTraits, ExtractsOPaddingGemmPadding)
{
    namespace ck_device = ck::tensor_operation::device;
    using Instance      = GemmPaddingTestInstance<ck_device::GemmSpecialization::OPadding>;

    const auto conv_traits = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();

    EXPECT_EQ(conv_traits.gemm_padding, GemmPadding::O_PADDING);
}
// GemmSpecialization::MOPadding must map to GemmPadding::MO_PADDING.
TEST(InstanceToConvTraits, ExtractsMoPaddingGemmPadding)
{
    namespace ck_device = ck::tensor_operation::device;
    using Instance      = GemmPaddingTestInstance<ck_device::GemmSpecialization::MOPadding>;

    const auto conv_traits = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();

    EXPECT_EQ(conv_traits.gemm_padding, GemmPadding::MO_PADDING);
}
// GemmSpecialization::NOPadding must map to GemmPadding::NO_PADDING.
TEST(InstanceToConvTraits, ExtractsNoPaddingGemmPadding)
{
    namespace ck_device = ck::tensor_operation::device;
    using Instance      = GemmPaddingTestInstance<ck_device::GemmSpecialization::NOPadding>;

    const auto conv_traits = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();

    EXPECT_EQ(conv_traits.gemm_padding, GemmPadding::NO_PADDING);
}
// GemmSpecialization::KOPadding must map to GemmPadding::KO_PADDING.
TEST(InstanceToConvTraits, ExtractsKoPaddingGemmPadding)
{
    namespace ck_device = ck::tensor_operation::device;
    using Instance      = GemmPaddingTestInstance<ck_device::GemmSpecialization::KOPadding>;

    const auto conv_traits = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();

    EXPECT_EQ(conv_traits.gemm_padding, GemmPadding::KO_PADDING);
}
// GemmSpecialization::MNOPadding must map to GemmPadding::MNO_PADDING.
TEST(InstanceToConvTraits, ExtractsMnoPaddingGemmPadding)
{
    namespace ck_device = ck::tensor_operation::device;
    using Instance      = GemmPaddingTestInstance<ck_device::GemmSpecialization::MNOPadding>;

    const auto conv_traits = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();

    EXPECT_EQ(conv_traits.gemm_padding, GemmPadding::MNO_PADDING);
}
// GemmSpecialization::MKOPadding must map to GemmPadding::MKO_PADDING.
TEST(InstanceToConvTraits, ExtractsMkoPaddingGemmPadding)
{
    namespace ck_device = ck::tensor_operation::device;
    using Instance      = GemmPaddingTestInstance<ck_device::GemmSpecialization::MKOPadding>;

    const auto conv_traits = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();

    EXPECT_EQ(conv_traits.gemm_padding, GemmPadding::MKO_PADDING);
}
// GemmSpecialization::NKOPadding must map to GemmPadding::NKO_PADDING.
TEST(InstanceToConvTraits, ExtractsNkoPaddingGemmPadding)
{
    namespace ck_device = ck::tensor_operation::device;
    using Instance      = GemmPaddingTestInstance<ck_device::GemmSpecialization::NKOPadding>;

    const auto conv_traits = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();

    EXPECT_EQ(conv_traits.gemm_padding, GemmPadding::NKO_PADDING);
}
// GemmSpecialization::MNKOPadding must map to GemmPadding::MNKO_PADDING.
TEST(InstanceToConvTraits, ExtractsMnkoPaddingGemmPadding)
{
    namespace ck_device = ck::tensor_operation::device;
    using Instance      = GemmPaddingTestInstance<ck_device::GemmSpecialization::MNKOPadding>;

    const auto conv_traits = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();

    EXPECT_EQ(conv_traits.gemm_padding, GemmPadding::MNKO_PADDING);
}
} // anonymous namespace

View File

@@ -0,0 +1,262 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
// ============================================================================
// Unit Tests for Complete Device Instance Transformations
// ============================================================================
//
// PURPOSE:
// --------
// These tests verify the complete instance_to_conv_traits transformation
// for entire Device class templates. Each test validates that all traits
// are correctly extracted from a specific Device class instantiation.
//
// TEST COVERAGE:
// --------------
// Complete transformation verification for each XDL Device class template:
// 1. DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// 2. DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
// 3. DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
//
// Each test verifies:
// - Spatial dimension extraction
// - Convolution direction
// - Data type detection
// - GEMM padding configuration
// - Tile dimensions (M, N, K per block)
// - Pipeline scheduler and version
// ============================================================================
#include <gtest/gtest.h>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck_tile/builder/reflect/instance_to_conv_traits.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
#include <ck_tile/builder/reflect/instance_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp>
namespace {
using ::ck_tile::builder::ConvDirection;
using ::ck_tile::builder::DataType;
using ::ck_tile::builder::GemmPadding;
using ::ck_tile::builder::PipelineScheduler;
using ::ck_tile::builder::PipelineVersion;
// ============================================================================
// Comprehensive Transformation Tests - Per Device Class Template
// ============================================================================
// These tests verify the complete InstanceTraits → ConvTraits transformation
// for each forward convolution Device class template.
// ============================================================================
// End-to-end check for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3: a 2D
// FP16 forward instance is converted and the resulting traits are compared
// against the instance's InstanceTraits constants and the expected builder enums.
TEST(InstanceToConvTraits, TransformsFwdMultipleAbdXdlCShuffleV3)
{
    namespace conv_layout = ck::tensor_layout::convolution;
    namespace ck_device   = ck::tensor_operation::device;
    using PassThrough     = ck::tensor_operation::element_wise::PassThrough;
    using NoDs            = ck::Tuple<>;
    // The A and B block transfers use the same cluster lengths and orderings.
    using ClusterLengths = ck::Sequence<4, 64, 1>;
    using AccessOrder    = ck::Sequence<1, 0, 2>;

    using Instance = ck_device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<
        2,                                                    // NDimSpatial
        conv_layout::GNHWC,                                   // ALayout
        conv_layout::GKYXC,                                   // BLayout
        NoDs,                                                 // DsLayout
        conv_layout::GNHWK,                                   // ELayout
        ck::half_t,                                           // ADataType
        ck::half_t,                                           // BDataType
        float,                                                // AccDataType
        ck::half_t,                                           // CShuffleDataType
        NoDs,                                                 // DsDataType
        ck::half_t,                                           // EDataType
        PassThrough,                                          // AElementwiseOperation
        PassThrough,                                          // BElementwiseOperation
        PassThrough,                                          // CDEElementwiseOperation
        ck_device::ConvolutionForwardSpecialization::Default, // ConvForwardSpecialization
        ck_device::GemmSpecialization::Default,               // GemmSpec
        256,                                                  // BlockSize
        128,                                                  // MPerBlock
        128,                                                  // NPerBlock
        16,                                                   // KPerBlock
        8,                                                    // AK1
        8,                                                    // BK1
        32,                                                   // MPerXDL
        32,                                                   // NPerXDL
        4,                                                    // MXdlPerWave
        4,                                                    // NXdlPerWave
        ClusterLengths,            // ABlockTransferThreadClusterLengths
        AccessOrder,               // ABlockTransferThreadClusterArrangeOrder
        AccessOrder,               // ABlockTransferSrcAccessOrder
        2,                         // ABlockTransferSrcVectorDim
        8,                         // ABlockTransferSrcScalarPerVector
        8,                         // ABlockTransferDstScalarPerVector_AK1
        1,                         // ABlockLdsExtraM
        ClusterLengths,            // BBlockTransferThreadClusterLengths
        AccessOrder,               // BBlockTransferThreadClusterArrangeOrder
        AccessOrder,               // BBlockTransferSrcAccessOrder
        2,                         // BBlockTransferSrcVectorDim
        8,                         // BBlockTransferSrcScalarPerVector
        8,                         // BBlockTransferDstScalarPerVector_BK1
        1,                         // BBlockLdsExtraN
        1,                         // CShuffleMXdlPerWavePerShuffle
        1,                         // CShuffleNXdlPerWavePerShuffle
        ck::Sequence<1, 32, 1, 8>, // CDEBlockTransferClusterLengths
        8,                         // CDEBlockTransferScalarPerVector_NPerBlock
        ck::BlockGemmPipelineScheduler::Intrawave, // BlkGemmPipeSched
        ck::BlockGemmPipelineVersion::v1,          // BlkGemmPipelineVer
        ck::half_t,                                // AComputeDataType
        ck::half_t,                                // BComputeDataType
        false>;                                    // DirectLoad

    using Reflected        = ck_tile::reflect::InstanceTraits<Instance>;
    const auto conv_traits = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();

    // Signature: dimensionality, direction, data type and GEMM padding.
    EXPECT_EQ(conv_traits.spatial_dim, Reflected::kSpatialDim);
    EXPECT_EQ(conv_traits.direction, ConvDirection::FORWARD);
    EXPECT_EQ(conv_traits.data_type, DataType::FP16);
    EXPECT_EQ(conv_traits.gemm_padding, GemmPadding::DEFAULT);

    // Block tile sizes must match what InstanceTraits reports.
    EXPECT_EQ(conv_traits.tile_dims.m, Reflected::kMPerBlock);
    EXPECT_EQ(conv_traits.tile_dims.n, Reflected::kNPerBlock);
    EXPECT_EQ(conv_traits.tile_dims.k, Reflected::kKPerBlock);

    // Pipeline: Intrawave scheduler with pipeline version v1.
    EXPECT_EQ(conv_traits.pipeline_scheduler, PipelineScheduler::INTRAWAVE);
    EXPECT_EQ(conv_traits.pipeline_version, PipelineVersion::V1);
}
// End-to-end check for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle: a 2D FP16
// forward instance is converted and the resulting traits are compared against
// the instance's InstanceTraits constants and the expected builder enums.
TEST(InstanceToConvTraits, TransformsFwdMultipleAbdXdlCShuffle)
{
    namespace conv_layout = ck::tensor_layout::convolution;
    namespace ck_device   = ck::tensor_operation::device;
    using PassThrough     = ck::tensor_operation::element_wise::PassThrough;
    using NoDs            = ck::Tuple<>;
    // The A and B block transfers use the same cluster lengths and orderings.
    using ClusterLengths = ck::Sequence<4, 64, 1>;
    using AccessOrder    = ck::Sequence<1, 0, 2>;

    using Instance = ck_device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle<
        2,                                                    // NDimSpatial
        conv_layout::GNHWC,                                   // ALayout
        conv_layout::GKYXC,                                   // BLayout
        NoDs,                                                 // DsLayout
        conv_layout::GNHWK,                                   // ELayout
        ck::half_t,                                           // ADataType
        ck::half_t,                                           // BDataType
        float,                                                // AccDataType
        ck::half_t,                                           // CShuffleDataType
        NoDs,                                                 // DsDataType
        ck::half_t,                                           // EDataType
        PassThrough,                                          // AElementwiseOperation
        PassThrough,                                          // BElementwiseOperation
        PassThrough,                                          // CDEElementwiseOperation
        ck_device::ConvolutionForwardSpecialization::Default, // ConvForwardSpecialization
        ck_device::GemmSpecialization::Default,               // GemmSpec
        1,                                                    // NumGemmKPrefetchStage
        256,                                                  // BlockSize
        128,                                                  // MPerBlock
        128,                                                  // NPerBlock
        16,                                                   // KPerBlock
        8,                                                    // AK1
        8,                                                    // BK1
        32,                                                   // MPerXDL
        32,                                                   // NPerXDL
        4,                                                    // MXdlPerWave
        4,                                                    // NXdlPerWave
        ClusterLengths,             // ABlockTransferThreadClusterLengths
        AccessOrder,                // ABlockTransferThreadClusterArrangeOrder
        AccessOrder,                // ABlockTransferSrcAccessOrder
        2,                          // ABlockTransferSrcVectorDim
        8,                          // ABlockTransferSrcScalarPerVector
        8,                          // ABlockTransferDstScalarPerVector_AK1
        1,                          // ABlockLdsExtraM
        ClusterLengths,             // BBlockTransferThreadClusterLengths
        AccessOrder,                // BBlockTransferThreadClusterArrangeOrder
        AccessOrder,                // BBlockTransferSrcAccessOrder
        2,                          // BBlockTransferSrcVectorDim
        8,                          // BBlockTransferSrcScalarPerVector
        8,                          // BBlockTransferDstScalarPerVector_BK1
        1,                          // BBlockLdsExtraN
        1,                          // CShuffleMXdlPerWavePerShuffle
        1,                          // CShuffleNXdlPerWavePerShuffle
        ck::Sequence<1, 32, 1, 8>,  // CDEBlockTransferClusterLengths
        8,                          // CDEBlockTransferScalarPerVector_NPerBlock
        ck::half_t,                 // AComputeDataType
        ck::half_t,                 // BComputeDataType
        ck::LoopScheduler::Default, // LoopSched
        1>;                         // NumGroupsToMerge

    using Reflected        = ck_tile::reflect::InstanceTraits<Instance>;
    const auto conv_traits = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();

    // Signature: dimensionality, direction, data type and GEMM padding.
    EXPECT_EQ(conv_traits.spatial_dim, Reflected::kSpatialDim);
    EXPECT_EQ(conv_traits.direction, ConvDirection::FORWARD);
    EXPECT_EQ(conv_traits.data_type, DataType::FP16);
    EXPECT_EQ(conv_traits.gemm_padding, GemmPadding::DEFAULT);

    // Block tile sizes must match what InstanceTraits reports.
    EXPECT_EQ(conv_traits.tile_dims.m, Reflected::kMPerBlock);
    EXPECT_EQ(conv_traits.tile_dims.n, Reflected::kNPerBlock);
    EXPECT_EQ(conv_traits.tile_dims.k, Reflected::kKPerBlock);

    // Pipeline: this device class takes a ck::LoopScheduler rather than a
    // BlockGemmPipelineScheduler; LoopScheduler::Default is expected to surface
    // as PipelineScheduler::DEFAULT with pipeline version V1.
    EXPECT_EQ(conv_traits.pipeline_scheduler, PipelineScheduler::DEFAULT);
    EXPECT_EQ(conv_traits.pipeline_version, PipelineVersion::V1);
}
// End-to-end check for DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor:
// a 2D FP16 forward instance is converted and the resulting traits are compared
// against the instance's InstanceTraits constants and the expected builder enums.
TEST(InstanceToConvTraits, TransformsFwdMultipleDXdlLargeTensor)
{
    namespace conv_layout = ck::tensor_layout::convolution;
    namespace ck_device   = ck::tensor_operation::device;
    using PassThrough     = ck::tensor_operation::element_wise::PassThrough;
    using NoDs            = ck::Tuple<>;
    // The A and B block transfers use the same cluster lengths and orderings.
    using ClusterLengths = ck::Sequence<4, 64, 1>;
    using AccessOrder    = ck::Sequence<1, 0, 2>;

    using Instance = ck_device::DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor<
        2,                                                    // NDimSpatial
        conv_layout::GNHWC,                                   // ALayout
        conv_layout::GKYXC,                                   // BLayout
        NoDs,                                                 // DsLayout
        conv_layout::GNHWK,                                   // ELayout
        ck::half_t,                                           // ADataType
        ck::half_t,                                           // BDataType
        float,                                                // AccDataType
        ck::half_t,                                           // CShuffleDataType
        NoDs,                                                 // DsDataType
        ck::half_t,                                           // EDataType
        PassThrough,                                          // AElementwiseOperation
        PassThrough,                                          // BElementwiseOperation
        PassThrough,                                          // CDEElementwiseOperation
        ck_device::ConvolutionForwardSpecialization::Default, // ConvForwardSpecialization
        ck_device::GemmSpecialization::Default,               // GemmSpec
        1,                                                    // NumGemmKPrefetchStage
        256,                                                  // BlockSize
        128,                                                  // MPerBlock
        128,                                                  // NPerBlock
        16,                                                   // KPerBlock
        8,                                                    // AK1
        8,                                                    // BK1
        32,                                                   // MPerXDL
        32,                                                   // NPerXDL
        4,                                                    // MXdlPerWave
        4,                                                    // NXdlPerWave
        ClusterLengths,              // ABlockTransferThreadClusterLengths
        AccessOrder,                 // ABlockTransferThreadClusterArrangeOrder
        AccessOrder,                 // ABlockTransferSrcAccessOrder
        2,                           // ABlockTransferSrcVectorDim
        8,                           // ABlockTransferSrcScalarPerVector
        8,                           // ABlockTransferDstScalarPerVector_AK1
        1,                           // ABlockLdsExtraM
        ClusterLengths,              // BBlockTransferThreadClusterLengths
        AccessOrder,                 // BBlockTransferThreadClusterArrangeOrder
        AccessOrder,                 // BBlockTransferSrcAccessOrder
        2,                           // BBlockTransferSrcVectorDim
        8,                           // BBlockTransferSrcScalarPerVector
        8,                           // BBlockTransferDstScalarPerVector_BK1
        1,                           // BBlockLdsExtraN
        1,                           // CShuffleMXdlPerWavePerShuffle
        1,                           // CShuffleNXdlPerWavePerShuffle
        ck::Sequence<1, 32, 1, 8>,   // CDEBlockTransferClusterLengths
        8,                           // CDEBlockTransferScalarPerVector_NPerBlock
        ck::half_t,                  // AComputeDataType
        ck::half_t,                  // BComputeDataType
        ck::LoopScheduler::Default>; // LoopSched

    using Reflected        = ck_tile::reflect::InstanceTraits<Instance>;
    const auto conv_traits = ck_tile::reflect::conv::instance_to_conv_traits<Instance>();

    // Signature: dimensionality, direction, data type and GEMM padding.
    EXPECT_EQ(conv_traits.spatial_dim, Reflected::kSpatialDim);
    EXPECT_EQ(conv_traits.direction, ConvDirection::FORWARD);
    EXPECT_EQ(conv_traits.data_type, DataType::FP16);
    EXPECT_EQ(conv_traits.gemm_padding, GemmPadding::DEFAULT);

    // Block tile sizes must match what InstanceTraits reports.
    EXPECT_EQ(conv_traits.tile_dims.m, Reflected::kMPerBlock);
    EXPECT_EQ(conv_traits.tile_dims.n, Reflected::kNPerBlock);
    EXPECT_EQ(conv_traits.tile_dims.k, Reflected::kKPerBlock);

    // Pipeline: the instance is configured with ck::LoopScheduler::Default and
    // is expected to surface as PipelineScheduler::DEFAULT with version V1.
    EXPECT_EQ(conv_traits.pipeline_scheduler, PipelineScheduler::DEFAULT);
    EXPECT_EQ(conv_traits.pipeline_version, PipelineVersion::V1);
}
} // anonymous namespace