[CK_BUILDER] First fwd convolution builder implementation (#3070)

* Add experimental builder infrastructure for composable_kernel

- Add experimental/builder directory with README documentation.
- Create initial test infrastructure with CMakeLists.txt and placeholder test.
- Update root CMakeLists.txt to support CK_EXPERIMENTAL_BUILDER option.
- Update .gitignore to not treat `experimental/builder` as a CMake build directory.

This establishes the directory structure for a high-level builder pattern that will provide a semantically clear interface for constructing CK operations, with an initial focus on convolution kernels for MIOpen integration.

* Fix clang formatting.

* Fix CMake build infrastructure for experimental builder

- Add experimental/builder CMakeLists.txt with proper subdirectory structure
- Add placeholder include/ck_tile/builder CMakeLists.txt for header installation
- Fix gtest.cmake to use include_guard to prevent multiple inclusions
- Update root CMakeLists.txt to include full builder directory instead of just tests

* Scope C++20 setting to the test code

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Remove redundant GTest::gtest linkage

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Introduce basic types, and convolution algorithm concepts and limits.

* Add convolution signature concepts.

* Add convolution factory.

* Finalize conv factory implementation for fwd convolutions.

* Add type definitions for testing.

* Add placeholder test.

* Add convolution builder definition.

* Fully functional fwd conv builder.

* Test improvements.

* Clean-up include headers.

* Enable the limit checks for the convolution algorithm parameters.

* Remove dead code.

* clang formatting.

* Add more tests and missing conv specialization argument.

* clang formatting.

* Add explicit handling of the tensor layouts.

* Add complete 2D/3D layout support to CK Builder

  - Add missing 2D layouts: GNHWC_GKYXC_GNHWK, NGCHW_GKCYX_NGKHW
  - Add missing 3D layout: GNDHWC_GKZYXC_GNDHWK
  - Add 1D layouts (NWGC, NGCW, GNWC, NGCW_GKCX) for future support
  - Add 3 tests for new 2D/3D layouts
  - All tests pass (5/5)

* Add tests for remaining 2D/3D layouts

  - Add test for 2D NGCHW_GKYXC_NGKHW (channels-first) with Filter1x1Stride1Pad0
  - Add test for 3D NDHWGC_GKZYXC_NDHWGK (channels-last)
  - All 7 tests pass (complete coverage for all 2D/3D forward layouts)

* Change enum converters to consteval.

* 7 tests with pipeline and specialization:

  | Test # | Dim | Type | Layout               | Pipeline | Specialization          |
  |--------|-----|------|----------------------|----------|-------------------------|
  | 1      | 2D  | BF16 | NHWGC_GKYXC_NHWGK    | V1       | DEFAULT                 |
  | 2      | 2D  | FP16 | GNHWC_GKYXC_GNHWK    | V3       | FILTER_1X1_PAD0         |
  | 3      | 2D  | FP32 | NGCHW_GKCYX_NGKHW    | V4       | FILTER_1X1_STRIDE1_PAD0 |
  | 4      | 2D  | BF16 | NHWGC_GKYXC_NHWGK    | V5       | FILTER_3x3              |
  | 5      | 3D  | FP32 | NGCDHW_GKCZYX_NGKDHW | V1       | FILTER_1X1_PAD0         |
  | 6      | 3D  | BF16 | GNDHWC_GKZYXC_GNDHWK | V3       | DEFAULT                 |
  | 7      | 3D  | FP16 | NDHWGC_GKZYXC_NDHWGK | V4       | FILTER_1X1_PAD0         |

* Add missing convolution layouts and provide better compile-time error in instance traits.

* Fix clang formatting.

* Changed I8 -> S8.

* Fix signature.

* Rename concepts and corresponding members.

* Rename LDS related parameters.

* Remove ODD_C specialization. Add V2 pipeline.

* Add missing types.

* Add elementwise operation to the conv signature.

* Improve compile-time error message for unsupported elementwise ops.

* Separate different fwd conv builder tests into separate compilation units.

* Fix layout to string and add name to old CK PassThrough elementwise op.

* Enable both CK and CK Tile tensor layouts in instance traits.

* Fix clang-format.

---------

Co-authored-by: John Shumway <jshumway@amd.com>
Co-authored-by: John Shumway <john.shumwayjr@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: JH-Leon-KIM-AMD <jeonghyun.kim@amd.com>
Author: Ville Pietilä
Committed: 2025-10-27 20:09:24 +02:00 (via GitHub)
Parent: 5c1974065e
Commit: 6c2ca1211a
21 changed files with 1527 additions and 3 deletions


@@ -23,9 +23,18 @@ This project is a prototype for a more general builder pattern for all of compos
To enable the experimental builder, configure your build with:
```bash
cmake \
  -D CMAKE_PREFIX_PATH=/opt/rocm \
  -D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
  -D CMAKE_BUILD_TYPE=Release \
  -D GPU_TARGETS="gfx942;gfx950" \
  -D CK_EXPERIMENTAL_BUILDER=ON \
  -D CMAKE_CXX_STANDARD=20 \
  -G Ninja \
  ..
```
## Building and testing
During development, build and test from the CK build directory with


@@ -0,0 +1,143 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstddef>
#include <string_view>
#include <utility>
#include "ck/utility/sequence.hpp"
#include "ck_tile/builder/types.hpp"
namespace ck_tile::builder {
// Convert a static array to a sequence
// Usage example:
// static constexpr std::array arr {1, 2, 3};
// using seq = to_sequence_v<arr>; // seq is ck::Sequence<1, 2, 3>
template <typename T, const T& Arr>
struct to_sequence_t
{
private:
template <std::size_t... Is>
static auto get_sequence_type(std::index_sequence<Is...>) -> ck::Sequence<Arr[Is]...>;
// Helper method to handle the unusual .Size() method name in ck::Array.
static constexpr auto get_size(const auto& arr)
{
if constexpr(requires { arr.size(); })
{
return arr.size();
}
else
{
return arr.Size();
}
}
public:
using value = decltype(get_sequence_type(std::make_index_sequence<get_size(Arr)>{}));
};
template <auto& Arr>
using to_sequence_v = typename to_sequence_t<std::remove_cvref_t<decltype(Arr)>, Arr>::value;
// Wrapper function to make constexpr strings a structural type for NTTP.
template <size_t N>
struct StringLiteral
{
char data[N];
constexpr StringLiteral(const char (&str)[N])
{
for(size_t i = 0; i < N; ++i)
data[i] = str[i];
}
constexpr bool operator==(const StringLiteral<N>& other) const
{
for(size_t i = 0; i < N; ++i)
{
if(data[i] != other.data[i])
{
return false;
}
}
return true;
}
};
// This is a C++17 deduction guide. It allows the compiler to automatically
// deduce the template argument `N` for `StringLiteral` from a string literal
// constructor argument. For example, you can write `StringLiteral s{"foo"};`
// instead of `StringLiteral<4> s{"foo"};`.
template <size_t N>
StringLiteral(const char (&)[N]) -> StringLiteral<N>;
// Helper to provide a readable error for unsupported enum values.
// The compiler will print the name of this struct in the error message, so
// the name of the enum value will appear instead of just its integer value.
template <auto T>
struct UnsupportedEnumValue
{
};
// Helper functions to convert enums to strings
constexpr std::string_view ConvDirectionToString(ConvDirection dir)
{
switch(dir)
{
case ConvDirection::FORWARD: return "Forward";
case ConvDirection::BACKWARD_DATA: return "Backward Data";
case ConvDirection::BACKWARD_WEIGHT: return "Backward Weight";
default: return "Unknown";
}
}
constexpr std::string_view DataTypeToString(DataType dt)
{
switch(dt)
{
case DataType::FP16: return "FP16";
case DataType::FP32: return "FP32";
case DataType::BF16: return "BF16";
case DataType::FP8: return "FP8";
case DataType::I8: return "I8";
case DataType::U8: return "U8";
default: return "Unknown";
}
}
constexpr std::string_view LayoutToString(GroupConvLayout1D layout)
{
switch(layout)
{
case GroupConvLayout1D::GNWC_GKXC_GNWK: return "GNWC_GKXC_GNWK";
case GroupConvLayout1D::NWGC_GKXC_NWGK: return "NWGC_GKXC_NWGK";
case GroupConvLayout1D::NGCW_GKXC_NGKW: return "NGCW_GKXC_NGKW";
case GroupConvLayout1D::NGCW_GKCX_NGKW: return "NGCW_GKCX_NGKW";
default: return "Unknown";
}
}
constexpr std::string_view LayoutToString(GroupConvLayout2D layout)
{
switch(layout)
{
case GroupConvLayout2D::GNHWC_GKYXC_GNHWK: return "GNHWC_GKYXC_GNHWK";
case GroupConvLayout2D::NHWGC_GKYXC_NHWGK: return "NHWGC_GKYXC_NHWGK";
case GroupConvLayout2D::NGCHW_GKYXC_NGKHW: return "NGCHW_GKYXC_NGKHW";
case GroupConvLayout2D::NGCHW_GKCYX_NGKHW: return "NGCHW_GKCYX_NGKHW";
default: return "Unknown";
}
}
constexpr std::string_view LayoutToString(GroupConvLayout3D layout)
{
switch(layout)
{
case GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK: return "GNDHWC_GKZYXC_GNDHWK";
case GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK: return "NDHWGC_GKZYXC_NDHWGK";
case GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW: return "NGCDHW_GKZYXC_NGKDHW";
case GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW: return "NGCDHW_GKCZYX_NGKDHW";
default: return "Unknown";
}
}
} // namespace ck_tile::builder


@@ -0,0 +1,141 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <type_traits>
#include <concepts>
#include <array>
#include "ck_tile/builder/types.hpp"
namespace ck_tile::builder {
/********************************************************************/
/* Descriptors for individual elements of the algorithm description */
/********************************************************************/
// Concept for thread block dimensions for a GEMM problem.
template <typename T>
concept ThreadBlockDescriptor = requires(T t) {
{ t.block_size } -> std::convertible_to<size_t>;
{ t.tile_size.m } -> std::convertible_to<size_t>;
{ t.tile_size.n } -> std::convertible_to<size_t>;
{ t.tile_size.k } -> std::convertible_to<size_t>;
};
// Concept for parameters that describe a gridwise GEMM problem.
template <typename T>
concept GridwiseGemmDescriptor = requires(T t) {
{ t.ak1 } -> std::convertible_to<size_t>;
{ t.bk1 } -> std::convertible_to<size_t>;
{ t.m_per_xdl } -> std::convertible_to<size_t>;
{ t.n_per_xdl } -> std::convertible_to<size_t>;
{ t.m_xdl_per_wave } -> std::convertible_to<size_t>;
{ t.n_xdl_per_wave } -> std::convertible_to<size_t>;
};
// Concept for vectorized data transfer for convolution input tensors.
template <typename T>
concept BlockTransferDescriptor = requires(T t) {
{ t.k0 } -> std::convertible_to<size_t>;
{ t.m_n } -> std::convertible_to<size_t>;
{ t.k1 } -> std::convertible_to<size_t>;
};
// Concept for thread cluster dimensions for GEMM output tensor.
template <typename T>
concept ThreadClusterDescriptor = requires(T t) {
{ t.m_block } -> std::convertible_to<size_t>;
{ t.m_wave_per_xdl } -> std::convertible_to<size_t>;
{ t.n_block } -> std::convertible_to<size_t>;
{ t.n_wave_per_xdl } -> std::convertible_to<size_t>;
};
// Concept for the LDS transfer for the convolution input tensors.
template <typename T>
concept LdsTransferDescriptor = requires(T t) {
{ t.src_vector_dim } -> std::convertible_to<size_t>;
{ t.src_scalar_per_vector } -> std::convertible_to<size_t>;
{ t.lds_dst_scalar_per_vector } -> std::convertible_to<size_t>;
{ t.is_direct_load } -> std::convertible_to<bool>;
{ t.lds_padding } -> std::convertible_to<bool>;
};
// Concept for the convolution output tensor epilogue (copy from registers to global memory via
// LDS).
template <typename T>
concept EpilogueDescriptor = requires(T t) {
{ t.m_xdl_per_wave_per_shuffle } -> std::convertible_to<size_t>;
{ t.n_xdl_per_wave_per_shuffle } -> std::convertible_to<size_t>;
{ t.scalar_per_vector } -> std::convertible_to<size_t>;
};
// Concept for the thread cluster access order
template <typename T>
concept AccessOrderDescriptor = requires(T t) {
{ t.order } -> std::convertible_to<std::array<size_t, 3>>;
};
// No requirements yet for a ConvAlgorithm concept.
template <typename T>
concept ConvAlgorithmDescriptor = std::is_class_v<T>;
/**********************************************/
/* Requirements for the algorithm description */
/**********************************************/
// Concept to check if struct specifies thread block info.
template <typename T>
concept SpecifiesThreadBlock = requires {
{ T::thread_block } -> ThreadBlockDescriptor;
};
// Concept to check if a struct specifies gridwise GEMM info.
template <typename T>
concept SpecifiesGridwiseGemm = requires {
{ T::gridwise_gemm } -> GridwiseGemmDescriptor;
};
// Concept to check if a struct specifies convolution input and output block transfer info.
template <typename T>
concept SpecifiesBlockTransfer = requires(T t) {
{ T::block_transfer.block_transfer_a } -> BlockTransferDescriptor;
{ T::block_transfer.block_transfer_b } -> BlockTransferDescriptor;
{ T::block_transfer.thread_cluster_dims_c } -> ThreadClusterDescriptor;
};
// Concept to check if a struct specifies LDS transfer info for tensors A, B, and C.
template <typename T>
concept SpecifiesLdsTransfer = requires(T t) {
{ T::block_transfer.lds_transfer_a } -> LdsTransferDescriptor;
{ T::block_transfer.lds_transfer_b } -> LdsTransferDescriptor;
{ T::block_transfer.epilogue_c } -> EpilogueDescriptor;
};
// Concept to check if a struct specifies thread cluster access order info.
template <typename T>
concept SpecifiesThreadClusterAccessOrder = requires(T t) {
{ T::block_transfer.block_transfer_access_order_a } -> AccessOrderDescriptor;
{ T::block_transfer.block_transfer_access_order_b } -> AccessOrderDescriptor;
};
// Concept to check if a struct specifies source access order info.
template <typename T>
concept SpecifiesSourceAccessOrder = requires(T t) {
{ T::block_transfer.src_access_order_a } -> AccessOrderDescriptor;
{ T::block_transfer.src_access_order_b } -> AccessOrderDescriptor;
};
// Concept to check if struct specifies block_gemm_pipeline_version.
template <typename T>
concept SpecifiesGemmPipelineVersion = requires {
{ T::pipeline_version } -> std::convertible_to<BlockGemmPipelineVersion>;
};
// Concept to check if a struct specifies the forward convolution specialization.
template <typename T>
concept SpecifiesFwdConcSpecialization = requires {
{ T::fwd_specialization } -> std::convertible_to<ConvFwdSpecialization>;
};
} // namespace ck_tile::builder


@@ -0,0 +1,33 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <type_traits>
#include <concepts>
namespace ck_tile::builder {
// Limits for input vector transfer.
template <auto Value>
concept InputVectorTransferLimits = requires {
requires Value.src_vector_dim > 0 && Value.src_scalar_per_vector > 0 &&
Value.lds_dst_scalar_per_vector > 0;
};
// Limits for output vector transfer.
template <auto Value>
concept OutputVectorTransferLimits = requires {
requires Value.scalar_per_vector > 0 && Value.m_xdl_per_wave_per_shuffle > 0 &&
Value.n_xdl_per_wave_per_shuffle > 0;
};
// Limits for access order. Must be a permutation of {0, 1, 2}.
template <auto Value>
concept AccessOrderLimits = requires {
requires((Value[0] != Value[1]) && (Value[0] != Value[2]) && (Value[1] != Value[2]) &&
(Value[0] >= 0 && Value[0] < 3) && (Value[1] >= 0 && Value[1] < 3) &&
(Value[2] >= 0 && Value[2] < 3));
};
} // namespace ck_tile::builder


@@ -0,0 +1,38 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <concepts>
#include <type_traits>
#include "ck_tile/builder/conv_factory.hpp"
#include "ck_tile/builder/versions.hpp"
namespace ck_tile::builder {
/**
* @brief Top-level builder for creating convolution kernel instances.
*
* This struct serves as the main entry point for generating a convolution kernel.
* It uses a factory pattern based on the provided signature, algorithm, and version
* to construct the appropriate kernel instance.
*
* @tparam SIGNATURE The convolution signature, which describes the mathematical functionality of
* the algorithm (e.g., data types, layouts, direction).
* @tparam ALGORITHM The specific convolution algorithm to be used for the implementation.
* @tparam VERSION The version of the builder implementation.
*/
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION = LATEST_API_VERSION>
requires SupportedVersion<VERSION> && ValidConvSignature<SIGNATURE>
struct ConvBuilder
{
static constexpr auto kVersion = VERSION;
using Factory = ConvFactory<SIGNATURE, ALGORITHM, VERSION>;
// Output: The kernel class.
using Instance = Factory::Instance;
};
} // namespace ck_tile::builder


@@ -0,0 +1,539 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// A factory for instantiating CK convolution kernels.
//
// This file translates a semantic description of a convolution operation
// (`ConvSignatureDescriptor` and `ConvAlgorithmDescriptor`) into specific,
// low-level template arguments required by the underlying CK device-level
// kernel implementations. This abstraction enables more complex build-time
// logic and simplifies the kernel specification.
//
// Key Components:
//
// Template Metaprogram:
// - ConvFactory: The main factory, with specializations for different
// convolution directions (currently only forward).
//
// Template Metaprogram Helpers:
// - ConvTensorLayouts: Maps layout enums to CK layout types for different
// spatial dimensions (2D/3D) and directions.
// - ConvTensorTypes: Maps data type enums (FP16, BF16, FP32) to C++ types used by CK.
// - ConvPassThroughOps: Hard-coded pass-through element-wise operations.
// - ConvSpec: Encapsulates convolution and GEMM specialization enums.
//
// `constexpr` Helper Functions:
// - SetThreadBlockInfo: Determines thread block dimensions and tile sizes.
// - SetConvTuningInfo: Sets XDL and AK1/BK1 tuning parameters.
// - SetFwdConvABlockTransfer: Configures A tensor block transfer parameters.
// - SetFwdConvBBlockTransfer: Configures B tensor block transfer parameters.
// - SetCBlockTransfer: Configures C tensor block transfer parameters.
// - SetBlockGemmPipelineVersion: Maps pipeline version enum to CK types.
//
// The primary entry point is the `ConvFactory` struct, which is currently
// specialized for forward convolutions and produces instances of
// DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3.
#pragma once
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/types.hpp"
#include "ck_tile/builder/versions.hpp"
namespace ck_tile::builder::factory_internal {
// Type mappings from the builder FwdGroupConvLayout enum classes to the CK tensor data types.
template <auto LayoutValue, size_t SPATIAL_DIM, ConvDirection DIR>
requires(ConvSpatialDim<SPATIAL_DIM> && ValidConvLayoutForSpatialDim<LayoutValue, SPATIAL_DIM>)
struct ConvTensorLayouts
{
// This will trigger if a specialization for the given layout is not found.
// We should always catch this in an earlier validation check.
using Layout = decltype(LayoutValue);
static_assert(sizeof(Layout) == 0,
"Internal error. Unsupported layout for convolution factory.");
};
// 1D Forward Convolution Layout Specializations
template <>
struct ConvTensorLayouts<GroupConvLayout1D::NWGC_GKXC_NWGK, 1, ConvDirection::FORWARD>
{
using ALayout = ck::tensor_layout::convolution::NWGC;
using BLayout = ck::tensor_layout::convolution::GKXC;
using DsLayout = ck::Tuple<>;
using ELayout = ck::tensor_layout::convolution::NWGK;
};
template <>
struct ConvTensorLayouts<GroupConvLayout1D::NGCW_GKXC_NGKW, 1, ConvDirection::FORWARD>
{
using ALayout = ck::tensor_layout::convolution::NGCW;
using BLayout = ck::tensor_layout::convolution::GKXC;
using DsLayout = ck::Tuple<>;
using ELayout = ck::tensor_layout::convolution::NGKW;
};
template <>
struct ConvTensorLayouts<GroupConvLayout1D::GNWC_GKXC_GNWK, 1, ConvDirection::FORWARD>
{
using ALayout = ck::tensor_layout::convolution::GNWC;
using BLayout = ck::tensor_layout::convolution::GKXC;
using DsLayout = ck::Tuple<>;
using ELayout = ck::tensor_layout::convolution::GNWK;
};
template <>
struct ConvTensorLayouts<GroupConvLayout1D::NGCW_GKCX_NGKW, 1, ConvDirection::FORWARD>
{
using ALayout = ck::tensor_layout::convolution::NGCW;
using BLayout = ck::tensor_layout::convolution::GKCX;
using DsLayout = ck::Tuple<>;
using ELayout = ck::tensor_layout::convolution::NGKW;
};
template <>
struct ConvTensorLayouts<GroupConvLayout2D::NGCHW_GKYXC_NGKHW, 2, ConvDirection::FORWARD>
{
using ALayout = ck::tensor_layout::convolution::NGCHW;
using BLayout = ck::tensor_layout::convolution::GKYXC;
using DsLayout = ck::Tuple<>;
using ELayout = ck::tensor_layout::convolution::NGKHW;
};
template <>
struct ConvTensorLayouts<GroupConvLayout2D::NHWGC_GKYXC_NHWGK, 2, ConvDirection::FORWARD>
{
using ALayout = ck::tensor_layout::convolution::NHWGC;
using BLayout = ck::tensor_layout::convolution::GKYXC;
using DsLayout = ck::Tuple<>;
using ELayout = ck::tensor_layout::convolution::NHWGK;
};
template <>
struct ConvTensorLayouts<GroupConvLayout2D::GNHWC_GKYXC_GNHWK, 2, ConvDirection::FORWARD>
{
using ALayout = ck::tensor_layout::convolution::GNHWC;
using BLayout = ck::tensor_layout::convolution::GKYXC;
using DsLayout = ck::Tuple<>;
using ELayout = ck::tensor_layout::convolution::GNHWK;
};
template <>
struct ConvTensorLayouts<GroupConvLayout2D::NGCHW_GKCYX_NGKHW, 2, ConvDirection::FORWARD>
{
using ALayout = ck::tensor_layout::convolution::NGCHW;
using BLayout = ck::tensor_layout::convolution::GKCYX;
using DsLayout = ck::Tuple<>;
using ELayout = ck::tensor_layout::convolution::NGKHW;
};
template <>
struct ConvTensorLayouts<GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW, 3, ConvDirection::FORWARD>
{
using ALayout = ck::tensor_layout::convolution::NGCDHW;
using BLayout = ck::tensor_layout::convolution::GKCZYX;
using DsLayout = ck::Tuple<>;
using ELayout = ck::tensor_layout::convolution::NGKDHW;
};
template <>
struct ConvTensorLayouts<GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK, 3, ConvDirection::FORWARD>
{
using ALayout = ck::tensor_layout::convolution::NDHWGC;
using BLayout = ck::tensor_layout::convolution::GKZYXC;
using DsLayout = ck::Tuple<>;
using ELayout = ck::tensor_layout::convolution::NDHWGK;
};
template <>
struct ConvTensorLayouts<GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK, 3, ConvDirection::FORWARD>
{
using ALayout = ck::tensor_layout::convolution::GNDHWC;
using BLayout = ck::tensor_layout::convolution::GKZYXC;
using DsLayout = ck::Tuple<>;
using ELayout = ck::tensor_layout::convolution::GNDHWK;
};
// Type mappings from builder convolution data type to CK tensor types.
template <DataType T>
struct ConvTensorTypes
{
// This will trigger if a specialization for the given DataType is not found.
// We should always catch this in an earlier validation check.
static_assert(sizeof(UnsupportedEnumValue<T>) == 0,
"Internal error. Unsupported data type for convolution factory.");
};
template <>
struct ConvTensorTypes<DataType::FP16>
{
using ADataType = ck::half_t;
using BDataType = ck::half_t;
using CShuffleDataType = ck::half_t;
using DsDataTypes = ck::Tuple<>;
using AccDataType = float;
using EDataType = ck::half_t;
};
template <>
struct ConvTensorTypes<DataType::BF16>
{
using ADataType = ck::bhalf_t;
using BDataType = ck::bhalf_t;
using CShuffleDataType = ck::bhalf_t;
using DsDataTypes = ck::Tuple<>;
using AccDataType = float;
using EDataType = ck::bhalf_t;
};
template <>
struct ConvTensorTypes<DataType::FP32>
{
using ADataType = float;
using BDataType = float;
using CShuffleDataType = float;
using DsDataTypes = ck::Tuple<>;
using AccDataType = float;
using EDataType = float;
};
template <ElementwiseOperation T>
struct ElementwiseOps
{
// This will trigger if a specialization for the given ElementwiseOperation is not found.
// We should always catch this in an earlier validation check.
static_assert(sizeof(UnsupportedEnumValue<T>) == 0,
"Internal error. Unsupported elementwise operation for convolution factory.");
};
template <>
struct ElementwiseOps<ElementwiseOperation::PASS_THROUGH>
{
using AElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
using BElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
using CDEElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
};
// The algorithm specializations for the convolution and GEMM.
template <typename CONV_ENUM>
requires(
std::is_same_v<CONV_ENUM, ck::tensor_operation::device::ConvolutionForwardSpecialization>)
struct ConvSpec
{
CONV_ENUM conv_spec;
ck::tensor_operation::device::GemmSpecialization gemm_spec;
};
// Deduction guide for ConvSpec to simplify brace initialization.
template <typename CONV_ENUM, typename GEMM_ENUM>
ConvSpec(CONV_ENUM, GEMM_ENUM) -> ConvSpec<CONV_ENUM>;
// Block info for a convolution.
struct MNK
{
size_t m{};
size_t n{};
size_t k{};
};
struct ConvBlock
{
size_t block_size = 0;
MNK per_block = {};
};
template <ConvAlgorithmDescriptor auto ALGORITHM>
constexpr ConvBlock SetThreadBlockInfo()
{
constexpr auto& TB = ALGORITHM.thread_block;
return ConvBlock{.block_size = TB.block_size,
.per_block = {.m = TB.tile_size.m, .n = TB.tile_size.n, .k = TB.tile_size.k}};
}
// Convolution tuning parameters.
struct GridwiseGemm
{
size_t ak1 = 0;
size_t bk1 = 0;
size_t m_per_xdl = 0;
size_t n_per_xdl = 0;
size_t m_xdl_per_wave = 0;
size_t n_xdl_per_wave = 0;
};
template <ConvSignatureDescriptor auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM>
constexpr GridwiseGemm SetGridwiseGemmInfo()
{
constexpr auto& TP = ALGORITHM.gridwise_gemm;
return GridwiseGemm{
.ak1 = TP.ak1,
.bk1 = TP.bk1,
.m_per_xdl = TP.m_per_xdl,
.n_per_xdl = TP.n_per_xdl,
.m_xdl_per_wave = TP.m_xdl_per_wave,
.n_xdl_per_wave = TP.n_xdl_per_wave,
};
}
// Block transfer parameters for A or B tensor.
struct BlockTransfer
{
ck::Array<size_t, 3> thread_cluster_dims = {0, 0, 0}; // k0, m, k1
ck::Array<size_t, 3> thread_cluster_order = {0, 0, 0};
ck::Array<size_t, 3> src_access_order = {0, 0, 0};
size_t src_vector_dim = 0;
size_t src_scalar_per_vector = 0;
size_t lds_dst_scalar_per_vector = 0;
bool is_direct_load = false;
bool lds_padding = false;
};
template <ConvAlgorithmDescriptor auto ALGORITHM>
constexpr BlockTransfer SetFwdConvABlockTransfer()
{
constexpr auto& TCL = ALGORITHM.block_transfer.block_transfer_a;
constexpr auto& TCO = ALGORITHM.block_transfer.block_transfer_access_order_a;
constexpr auto& SAO = ALGORITHM.block_transfer.src_access_order_a;
constexpr auto& LDS = ALGORITHM.block_transfer.lds_transfer_a;
BlockTransfer block_transfer{.thread_cluster_dims = {TCL.k0, TCL.m_n, TCL.k1},
.thread_cluster_order = {TCO.order[0], TCO.order[1], TCO.order[2]},
.src_access_order = {SAO.order[0], SAO.order[1], SAO.order[2]},
.src_vector_dim = LDS.src_vector_dim,
.src_scalar_per_vector = LDS.src_scalar_per_vector,
.lds_dst_scalar_per_vector = LDS.lds_dst_scalar_per_vector,
.is_direct_load = LDS.is_direct_load,
.lds_padding = LDS.lds_padding};
return block_transfer;
}
template <ConvAlgorithmDescriptor auto ALGORITHM>
constexpr BlockTransfer SetFwdConvBBlockTransfer()
{
constexpr auto& TCL = ALGORITHM.block_transfer.block_transfer_b;
constexpr auto& TCO = ALGORITHM.block_transfer.block_transfer_access_order_b;
constexpr auto& SAO = ALGORITHM.block_transfer.src_access_order_b;
constexpr auto& LDS = ALGORITHM.block_transfer.lds_transfer_b;
BlockTransfer block_transfer{.thread_cluster_dims = {TCL.k0, TCL.m_n, TCL.k1},
.thread_cluster_order = {TCO.order[0], TCO.order[1], TCO.order[2]},
.src_access_order = {SAO.order[0], SAO.order[1], SAO.order[2]},
.src_vector_dim = LDS.src_vector_dim,
.src_scalar_per_vector = LDS.src_scalar_per_vector,
.lds_dst_scalar_per_vector = LDS.lds_dst_scalar_per_vector,
.is_direct_load = LDS.is_direct_load,
.lds_padding = LDS.lds_padding};
return block_transfer;
}
// Block transfer parameters for C tensor.
struct CBlockTransfer
{
size_t m_xdl_per_wave_per_shuffle = 0;
size_t n_xdl_per_wave_per_shuffle = 0;
ck::Array<size_t, 4> thread_cluster_dims = {0, 0, 0, 0};
size_t scalar_per_vector = 0;
};
template <ConvSignatureDescriptor auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM>
constexpr CBlockTransfer SetCBlockTransfer()
{
constexpr auto& TCL = ALGORITHM.block_transfer.thread_cluster_dims_c;
constexpr auto& EPC = ALGORITHM.block_transfer.epilogue_c;
CBlockTransfer block_transfer{.m_xdl_per_wave_per_shuffle = EPC.m_xdl_per_wave_per_shuffle,
.n_xdl_per_wave_per_shuffle = EPC.n_xdl_per_wave_per_shuffle,
.thread_cluster_dims =
{
TCL.m_block,
TCL.m_wave_per_xdl,
TCL.n_block,
TCL.n_wave_per_xdl,
},
.scalar_per_vector = EPC.scalar_per_vector};
return block_transfer;
}
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
{
constexpr auto version = ALGORITHM.pipeline_version;
if constexpr(version == BlockGemmPipelineVersion::V1)
{
return ck::BlockGemmPipelineVersion::v1;
}
else if constexpr(version == BlockGemmPipelineVersion::V3)
{
return ck::BlockGemmPipelineVersion::v3;
}
else if constexpr(version == BlockGemmPipelineVersion::V4)
{
return ck::BlockGemmPipelineVersion::v4;
}
else if constexpr(version == BlockGemmPipelineVersion::V5)
{
return ck::BlockGemmPipelineVersion::v5;
}
else
{
static_assert(false, "Unknown BlockGemmPipelineVersion");
}
}
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdConvSpecialization()
{
constexpr auto specialization = ALGORITHM.fwd_specialization;
if constexpr(specialization == ConvFwdSpecialization::DEFAULT)
{
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
}
else if constexpr(specialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
{
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
}
else if constexpr(specialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
{
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
}
else if constexpr(specialization == ConvFwdSpecialization::FILTER_3x3)
{
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter3x3;
}
else
{
static_assert(false, "Unknown ConvFwdSpecialization");
}
}
} // namespace ck_tile::builder::factory_internal
namespace ck_tile::builder {
// Primary template for the convolution factory.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
auto VERSION>
struct ConvFactory;
// Factory specialization for an instance of a grouped forward convolution kernel.
template <ConvSignatureDescriptor auto SIGNATURE,
ConvAlgorithmDescriptor auto ALGORITHM,
StringLiteral VERSION>
requires ConvDirectionIsForward<SIGNATURE>
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
{
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
using Layouts =
factory_internal::ConvTensorLayouts<SIGNATURE.layout, SPATIAL_DIM, ConvDirection::FORWARD>;
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
using AlgorithmType = decltype(ALGORITHM);
// Check preconditions for the algorithm description.
static_assert(SPATIAL_DIM == 2 || SPATIAL_DIM == 3,
"Only 2D and 3D convolutions are supported in this factory.");
static_assert(SpecifiesThreadBlock<AlgorithmType>,
"The convolution algorithm descriptor must specify thread block info.");
static_assert(SpecifiesGridwiseGemm<AlgorithmType>,
"The convolution algorithm descriptor must specify gridwise GEMM info.");
static_assert(SpecifiesBlockTransfer<AlgorithmType>,
"The convolution algorithm descriptor must specify block transfer info.");
static_assert(SpecifiesLdsTransfer<AlgorithmType>,
"The convolution algorithm descriptor must specify LDS transfer info.");
static_assert(
SpecifiesThreadClusterAccessOrder<AlgorithmType>,
"The convolution algorithm descriptor must specify thread cluster access order info.");
static_assert(SpecifiesSourceAccessOrder<AlgorithmType>,
"The convolution algorithm descriptor must specify source access order info.");
static_assert(SpecifiesGemmPipelineVersion<AlgorithmType>,
"The convolution algorithm descriptor must specify block gemm pipeline version.");
static_assert(SpecifiesFwdConcSpecialization<AlgorithmType>,
"The convolution algorithm descriptor must specify forward convolution "
"specialization.");
static constexpr auto FWD_CONV_SPECIALIZATION =
factory_internal::SetFwdConvSpecialization<ALGORITHM>();
static constexpr factory_internal::ConvSpec SPECIALIZATION{
.conv_spec = FWD_CONV_SPECIALIZATION,
.gemm_spec = ck::tensor_operation::device::GemmSpecialization::MNKPadding,
};
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM =
factory_internal::SetGridwiseGemmInfo<SIGNATURE, ALGORITHM>();
static constexpr auto A_BLOCK_TRANSFER =
factory_internal::SetFwdConvABlockTransfer<ALGORITHM>();
static constexpr auto B_BLOCK_TRANSFER =
factory_internal::SetFwdConvBBlockTransfer<ALGORITHM>();
static constexpr auto C_BLOCK_TRANSFER =
factory_internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave;
static constexpr auto PIPELINE_VERSION =
factory_internal::SetBlockGemmPipelineVersion<ALGORITHM>();
// Check limits for the algorithm parameters.
// TODO: Add more limits checks as needed.
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
// The forward convolution kernel class instance.
using Instance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< //
SPATIAL_DIM,
typename Layouts::ALayout,
typename Layouts::BLayout,
typename Layouts::DsLayout,
typename Layouts::ELayout,
typename Types::ADataType,
typename Types::BDataType,
typename Types::AccDataType,
typename Types::CShuffleDataType,
typename Types::DsDataTypes,
typename Types::EDataType,
typename Ops::AElementwiseOp,
typename Ops::BElementwiseOp,
typename Ops::CDEElementwiseOp,
SPECIALIZATION.conv_spec,
SPECIALIZATION.gemm_spec,
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.ak1,
GRIDWISE_GEMM.bk1,
GRIDWISE_GEMM.m_per_xdl,
GRIDWISE_GEMM.n_per_xdl,
GRIDWISE_GEMM.m_xdl_per_wave,
GRIDWISE_GEMM.n_xdl_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
A_BLOCK_TRANSFER.src_vector_dim,
A_BLOCK_TRANSFER.src_scalar_per_vector,
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
A_BLOCK_TRANSFER.lds_padding,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
B_BLOCK_TRANSFER.src_vector_dim,
B_BLOCK_TRANSFER.src_scalar_per_vector,
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
B_BLOCK_TRANSFER.lds_padding,
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
C_BLOCK_TRANSFER.scalar_per_vector,
PIPELINE_SCHEDULER,
PIPELINE_VERSION>;
};
} // namespace ck_tile::builder


@@ -0,0 +1,74 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
// This file defines the compile-time "signature" for grouped convolution operations.
// A signature is a collection of properties that fully describe a convolution kernel's
// mathematical characteristics. It uses C++20 concepts and enums to specify these
// properties, enabling compile-time validation and specialization.
//
// The core components of a signature are:
// - Spatial dimensionality (1D, 2D, 3D)
// - Operational direction (Forward, Backward Data, Backward Weight)
// - Tensor memory layout (Channels First/Last)
// - Data type (FP32, FP16, BF16)
// - Fused element-wise operation (e.g., Bias, Clamp)
//
// The file also provides predicate concepts to query the properties of a given
// signature at compile time.
#pragma once
#include <concepts>
#include <type_traits>
#include "ck_tile/builder/types.hpp"
namespace ck_tile::builder {
// Constrains convolution to 1D, 2D, or 3D spatial dimensions.
template <auto N>
concept ConvSpatialDim = std::is_integral_v<decltype(N)> && (N == 1 || N == 2 || N == 3);
// Constraints for forward convolution layouts.
template <auto LayoutValue, size_t SpatialDim>
concept ValidConvLayoutForSpatialDim =
(SpatialDim == 1 && std::same_as<decltype(LayoutValue), GroupConvLayout1D>) ||
(SpatialDim == 2 && std::same_as<decltype(LayoutValue), GroupConvLayout2D>) ||
(SpatialDim == 3 && std::same_as<decltype(LayoutValue), GroupConvLayout3D>);
// Constrains convolution data types to common floating-point types.
template <DataType T>
concept ConvDataType = (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16) ||
(T == DataType::FP8) || (T == DataType::I8) || (T == DataType::U8);
// Concept for a type that defines a convolution's operational signature.
template <typename T>
concept ConvSignatureDescriptor = requires(T t) {
{ t.spatial_dim } -> std::convertible_to<unsigned int>;
{ t.direction } -> std::convertible_to<ConvDirection>;
requires std::convertible_to<decltype(t.layout), GroupConvLayout1D> ||
std::convertible_to<decltype(t.layout), GroupConvLayout2D> ||
std::convertible_to<decltype(t.layout), GroupConvLayout3D>;
{ t.data_type } -> std::convertible_to<DataType>;
{ t.elementwise_operation } -> std::convertible_to<ElementwiseOperation>;
};
// Concept to validate a convolution signature's values.
template <auto Sig>
concept ValidConvSignature = requires {
requires ConvSpatialDim<Sig.spatial_dim>;
requires ConvDataType<Sig.data_type>;
};
// Predicate for forward convolution.
template <auto Sig>
concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD);
// Predicate for backward data convolution.
template <auto Sig>
concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA);
// Predicate for backward weight convolution.
template <auto Sig>
concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT);
} // namespace ck_tile::builder


@@ -16,6 +16,7 @@
 #include <ck/utility/sequence.hpp>
 #include <ck/utility/blkgemmpipe_scheduler.hpp>
 #include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
+#include <ck_tile/ops/common/tensor_layout.hpp>
 #include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
 #include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
 #include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
@@ -62,7 +63,9 @@ consteval std::string_view type_name()
template <typename T>
constexpr std::string_view layout_name()
{
-    if constexpr(std::is_base_of_v<ck_tile::tensor_layout::BaseTensorLayout, T> && requires {
+    if constexpr((std::is_base_of_v<ck_tile::tensor_layout::BaseTensorLayout, T> ||
+                  std::is_base_of_v<ck::tensor_layout::BaseTensorLayout, T>) &&
+                 requires {
                     { T::name } -> std::convertible_to<std::string_view>;
                 })
         return T::name;


@@ -0,0 +1,90 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
namespace ck_tile::builder {
enum class DataType
{
FP32,
FP16,
BF16,
FP8,
I8,
U8
};
// Memory layouts for 1D convolution tensors.
// G: Group, N: Batch, K: Output Channel, C: Input Channel, X: Filter Width, W: Width.
// Each enumerator names the Input, Weight, and Output tensor layouts, respectively.
enum class GroupConvLayout1D
{
GNWC_GKXC_GNWK,
NWGC_GKXC_NWGK,
NGCW_GKXC_NGKW,
NGCW_GKCX_NGKW
};
// Memory layouts for 2D convolution tensors.
// G: Group, N: Batch, K: Output Channel, C: Input Channel,
// Y: Filter Height, X: Filter Width, H: Height, W: Width.
// Each enumerator names the Input, Weight, and Output tensor layouts, respectively.
enum class GroupConvLayout2D
{
GNHWC_GKYXC_GNHWK,
NHWGC_GKYXC_NHWGK,
NGCHW_GKYXC_NGKHW,
NGCHW_GKCYX_NGKHW
};
// Memory layouts for 3D convolution tensors.
// G: Group, N: Batch, K: Output Channel, C: Input Channel,
// Z: Filter Depth, Y: Filter Height, X: Filter Width, D: Depth, H: Height, W: Width.
// Each enumerator names the Input, Weight, and Output tensor layouts, respectively.
enum class GroupConvLayout3D
{
GNDHWC_GKZYXC_GNDHWK,
NDHWGC_GKZYXC_NDHWGK,
NGCDHW_GKZYXC_NGKDHW,
NGCDHW_GKCZYX_NGKDHW,
};
// Direction of the convolution operation.
enum class ConvDirection
{
FORWARD,
BACKWARD_DATA,
BACKWARD_WEIGHT
};
// Fused element-wise operations.
enum class ElementwiseOperation
{
BIAS,
BIAS_CLAMP,
BIAS_BNORM_CLAMP,
BILINEAR,
CLAMP,
SCALE,
PASS_THROUGH
};
// Enums for the current block GEMM pipeline versions.
enum class BlockGemmPipelineVersion
{
V1,
V2,
V3,
V4,
V5
};
// Enums for the forward convolution specialization.
enum class ConvFwdSpecialization
{
DEFAULT,
FILTER_1X1_PAD0,
FILTER_1X1_STRIDE1_PAD0,
FILTER_3x3
};
} // namespace ck_tile::builder


@@ -0,0 +1,18 @@
#pragma once
#include <concepts>
#include <string_view>
#include "ck_tile/builder/builder_utils.hpp"
namespace ck_tile::builder {
static constexpr StringLiteral V0_0_0 = "0.0.0";
static constexpr StringLiteral V0_1_0 = "0.1.0";
static constexpr StringLiteral LATEST_API_VERSION = V0_1_0;
template <StringLiteral V>
concept SupportedVersion = (V == V0_0_0) || (V == V0_1_0);
} // namespace ck_tile::builder


@@ -7,6 +7,7 @@ function(add_ck_builder_test test_name)
 target_include_directories(${test_name} PRIVATE
 "${PROJECT_SOURCE_DIR}/experimental/builder/include"
 "${PROJECT_SOURCE_DIR}/include"
+"${CMAKE_CURRENT_SOURCE_DIR}"
 )
target_compile_options(${test_name} PRIVATE
-Wno-global-constructors
@@ -24,3 +25,11 @@ add_ck_builder_test(test_get_instance_string
test_get_instance_string.cpp)
add_ck_builder_test(test_inline_diff test_inline_diff.cpp testing_utils.cpp)
+add_ck_builder_test(test_ckb_build_fwd_instances
+conv/test_ckb_conv_fwd_2d_bf16.cpp
+conv/test_ckb_conv_fwd_2d_fp16.cpp
+conv/test_ckb_conv_fwd_2d_fp32.cpp
+conv/test_ckb_conv_fwd_3d_bf16.cpp
+conv/test_ckb_conv_fwd_3d_fp16.cpp
+conv/test_ckb_conv_fwd_3d_fp32.cpp)


@@ -0,0 +1,47 @@
#include "utils/ckb_conv_test_common.hpp"
using namespace ck_tile::builder::test_utils;
class FwdConv2DBF16Test : public FwdConvBuilderTestBase
{
};
// 2D BF16 NHWGC (channels-last) with Pipeline V1 and DEFAULT
TEST_F(FwdConv2DBF16Test,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_ChannelsLast)
{
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
.data_type = DataType::BF16,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
run_test<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V1,
ConvFwdSpecialization::DEFAULT>();
}
// 2D BF16 NHWGC (channels-last) with Pipeline V5 and FILTER_3x3
TEST_F(FwdConv2DBF16Test,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_NHWGC_Filter3x3)
{
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
.data_type = DataType::BF16,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
run_test<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V5,
ConvFwdSpecialization::FILTER_3x3>();
}


@@ -0,0 +1,26 @@
#include "utils/ckb_conv_test_common.hpp"
using namespace ck_tile::builder::test_utils;
class FwdConv2DFP16Test : public FwdConvBuilderTestBase
{
};
TEST_F(FwdConv2DFP16Test,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP16_GNHWC)
{
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
.data_type = DataType::FP16,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
run_test<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V3,
ConvFwdSpecialization::FILTER_1X1_PAD0>();
}


@@ -0,0 +1,26 @@
#include "utils/ckb_conv_test_common.hpp"
using namespace ck_tile::builder::test_utils;
class FwdConv2DFP32Test : public FwdConvBuilderTestBase
{
};
TEST_F(FwdConv2DFP32Test,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP32_NGCHW_GKCYX)
{
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
.spatial_dim = 2,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout2D::NGCHW_GKCYX_NGKHW,
.data_type = DataType::FP32,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 32}};
run_test<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V4,
ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>();
}


@@ -0,0 +1,27 @@
#include "utils/ckb_conv_test_common.hpp"
using namespace ck_tile::builder::test_utils;
class FwdConv3DBF16Test : public FwdConvBuilderTestBase
{
};
// 3D BF16 GNDHWC (group-first, channels-last) with Pipeline V3 and DEFAULT
TEST_F(FwdConv3DBF16Test,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_BF16_GNDHWC)
{
constexpr ConvSignature<GroupConvLayout3D> FwdConvSignature{
.spatial_dim = 3,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK,
.data_type = DataType::BF16,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
run_test<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V3,
ConvFwdSpecialization::DEFAULT>();
}


@@ -0,0 +1,27 @@
#include "utils/ckb_conv_test_common.hpp"
using namespace ck_tile::builder::test_utils;
class FwdConv3DFP16Test : public FwdConvBuilderTestBase
{
};
// 3D FP16 NDHWGC (channels-last) with Pipeline V4 and FILTER_1X1_PAD0
TEST_F(FwdConv3DFP16Test,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP16_NDHWGC_ChannelsLast)
{
constexpr ConvSignature<GroupConvLayout3D> FwdConvSignature{
.spatial_dim = 3,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK,
.data_type = DataType::FP16,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 32}};
run_test<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V4,
ConvFwdSpecialization::FILTER_1X1_PAD0>();
}


@@ -0,0 +1,27 @@
#include "utils/ckb_conv_test_common.hpp"
using namespace ck_tile::builder::test_utils;
class FwdConv3DFP32Test : public FwdConvBuilderTestBase
{
};
// 3D FP32 NGCDHW (channels-first) with Pipeline V1 and FILTER_1X1_PAD0
TEST_F(FwdConv3DFP32Test,
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst)
{
constexpr ConvSignature<GroupConvLayout3D> FwdConvSignature{
.spatial_dim = 3,
.direction = ConvDirection::FORWARD,
.layout = GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW,
.data_type = DataType::FP32,
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 32}};
run_test<FwdConvSignature,
FwdThreadBlock,
BlockGemmPipelineVersion::V1,
ConvFwdSpecialization::FILTER_1X1_PAD0>();
}


@@ -0,0 +1,119 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <array>
#include <cstddef>
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
namespace ck_tile::builder::test {
namespace ckb = ck_tile::builder;
// Convenience struct for a tuple of m, n, and k values.
template <typename T>
struct MNK
{
T m{};
T n{};
T k{};
};
// Specify thread block dimensions for a GEMM.
struct ThreadBlock
{
// Thread block size.
size_t block_size;
// Size of the submatrix problem in a thread block.
MNK<size_t> tile_size;
};
static_assert(ckb::ThreadBlockDescriptor<ThreadBlock>);
// Describe gridwise GEMM parameters.
struct GridwiseGemm
{
// NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation.
size_t ak1 = 0;
size_t bk1 = 0;
size_t m_per_xdl = 0;
size_t n_per_xdl = 0;
size_t m_xdl_per_wave = 0;
size_t n_xdl_per_wave = 0;
};
static_assert(ckb::GridwiseGemmDescriptor<GridwiseGemm>);
// Describe A and B block transfer thread cluster lengths.
struct BlockTransfer
{
size_t k0;
size_t m_n;
size_t k1;
};
static_assert(ckb::BlockTransferDescriptor<BlockTransfer>);
// Describe C block transfer thread cluster lengths.
struct ThreadCluster
{
size_t m_block;
size_t m_wave_per_xdl;
size_t n_block;
size_t n_wave_per_xdl;
};
static_assert(ThreadClusterDescriptor<ThreadCluster>);
struct LdsTransfer
{
size_t src_vector_dim;
size_t src_scalar_per_vector;
size_t lds_dst_scalar_per_vector;
bool is_direct_load;
bool lds_padding;
};
static_assert(LdsTransferDescriptor<LdsTransfer>);
struct Epilogue
{
size_t m_xdl_per_wave_per_shuffle;
size_t n_xdl_per_wave_per_shuffle;
size_t scalar_per_vector;
};
static_assert(EpilogueDescriptor<Epilogue>);
struct AccessOrder
{
std::array<size_t, 3> order;
};
static_assert(AccessOrderDescriptor<AccessOrder>);
struct BlockTransferABC
{
BlockTransfer block_transfer_a;
BlockTransfer block_transfer_b;
ThreadCluster thread_cluster_dims_c;
LdsTransfer lds_transfer_a;
LdsTransfer lds_transfer_b;
Epilogue epilogue_c;
AccessOrder block_transfer_access_order_a;
AccessOrder block_transfer_access_order_b;
AccessOrder src_access_order_a;
AccessOrder src_access_order_b;
};
struct ConvAlgorithm
{
ThreadBlock thread_block;
GridwiseGemm gridwise_gemm;
BlockTransferABC block_transfer;
BlockGemmPipelineVersion pipeline_version;
ConvFwdSpecialization fwd_specialization;
};
static_assert(ckb::ConvAlgorithmDescriptor<ConvAlgorithm>);
static_assert(ckb::SpecifiesThreadBlock<ConvAlgorithm>);
static_assert(ckb::SpecifiesGridwiseGemm<ConvAlgorithm>);
static_assert(ckb::SpecifiesBlockTransfer<ConvAlgorithm>);
static_assert(ckb::SpecifiesLdsTransfer<ConvAlgorithm>);
static_assert(ckb::SpecifiesThreadClusterAccessOrder<ConvAlgorithm>);
static_assert(ckb::SpecifiesSourceAccessOrder<ConvAlgorithm>);
static_assert(ckb::SpecifiesGemmPipelineVersion<ConvAlgorithm>);
static_assert(ckb::SpecifiesFwdConcSpecialization<ConvAlgorithm>);
} // namespace ck_tile::builder::test


@@ -0,0 +1,23 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/builder/conv_signature_concepts.hpp"
namespace ck_tile::builder::test {
template <typename GroupConvLayout>
struct ConvSignature
{
int spatial_dim;
ConvDirection direction;
GroupConvLayout layout;
DataType data_type;
ElementwiseOperation elementwise_operation;
};
static_assert(ConvSignatureDescriptor<ConvSignature<GroupConvLayout1D>>);
static_assert(ConvSignatureDescriptor<ConvSignature<GroupConvLayout2D>>);
static_assert(ConvSignatureDescriptor<ConvSignature<GroupConvLayout3D>>);
} // namespace ck_tile::builder::test


@@ -0,0 +1,103 @@
#pragma once
#include <gtest/gtest.h>
#include <iostream>
#include <string>
#include "impl/conv_algorithm_types.hpp"
#include "impl/conv_signature_types.hpp"
#include "ck_tile/builder/conv_builder.hpp"
namespace ck_tile::builder::test_utils {
using namespace ck_tile::builder;
using namespace test;
// Common test base class
class FwdConvBuilderTestBase : public ::testing::Test
{
};
// Common test implementation
template <auto FwdConvSignature,
ThreadBlock FwdThreadBlock,
BlockGemmPipelineVersion FwdPipelineVersion,
ConvFwdSpecialization FwdConvSpecialization>
constexpr void run_test()
{
constexpr GridwiseGemm FwdGemmParams{.ak1 = 8,
.bk1 = 8,
.m_per_xdl = 32,
.n_per_xdl = 32,
.m_xdl_per_wave = 4,
.n_xdl_per_wave = 4};
constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1},
.block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1},
.thread_cluster_dims_c = {.m_block = 1,
.m_wave_per_xdl = 32,
.n_block = 1,
.n_wave_per_xdl = 8},
.lds_transfer_a = {.src_vector_dim = 2,
.src_scalar_per_vector = 2,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = false},
.lds_transfer_b = {.src_vector_dim = 2,
.src_scalar_per_vector = 8,
.lds_dst_scalar_per_vector = 8,
.is_direct_load = false,
.lds_padding = false},
.epilogue_c = {.m_xdl_per_wave_per_shuffle = 1,
.n_xdl_per_wave_per_shuffle = 1,
.scalar_per_vector = 8},
.block_transfer_access_order_a = {1, 0, 2},
.block_transfer_access_order_b = {1, 0, 2},
.src_access_order_a = {1, 0, 2},
.src_access_order_b = {1, 0, 2}};
constexpr ConvAlgorithm FwdConvAlgorithm{.thread_block = FwdThreadBlock,
.gridwise_gemm = FwdGemmParams,
.block_transfer = FwdBlockTransfer,
.pipeline_version = FwdPipelineVersion,
.fwd_specialization = FwdConvSpecialization};
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
auto instance = typename Builder::Instance{};
const auto kernel_string = instance.GetTypeString();
std::cout << "Generated kernel: " << kernel_string << std::endl;
EXPECT_GT(kernel_string.size(), 0);
EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"));
// Verify pipeline version is correct
if(FwdPipelineVersion == BlockGemmPipelineVersion::V1)
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v1") != std::string::npos);
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V3)
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v3") != std::string::npos);
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V4)
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v4") != std::string::npos);
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V5)
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v5") != std::string::npos);
// Verify specialization is correct
if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT)
EXPECT_TRUE(kernel_string.find("Default") != std::string::npos);
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos);
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos);
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3)
EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos);
const auto invoker_ptr = instance.MakeInvokerPointer();
EXPECT_NE(invoker_ptr, nullptr);
}
// Common thread block configurations
constexpr ThreadBlock DefaultThreadBlock{.block_size = 256,
.tile_size = {.m = 256, .n = 256, .k = 32}};
constexpr ThreadBlock SmallThreadBlock{.block_size = 256,
.tile_size = {.m = 128, .n = 128, .k = 32}};
} // namespace ck_tile::builder::test_utils


@@ -552,6 +552,8 @@ struct PassThrough
 {
 y = type_convert<bf8_t>(x);
 }
+static constexpr const char* name = "PassThrough";
 };
struct UnaryConvert