mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_BUILDER] First fwd convolution builder implementation (#3070)
* Add experimental builder infrastructure for composable_kernel - Add experimental/builder directory with README documentation. - Create initial test infrastructure with CMakeLists.txt and placeholder test. - Update root CMakeLists.txt to support CK_EXPERIMENTAL_BUILDER option. - Update .gitignore to not treat `experimental/builder` as a CMake build directory. This establishes the directory structure for a high-level builder pattern that will provide a semantically-clear interface for constructing CK operations, with initial focus on convolution kernels for MIOpen integration. * Fix clang formatting. * Fix CMake build infrastructure for experimental builder - Add experimental/builder CMakeLists.txt with proper subdirectory structure - Add placeholder include/ck_tile/builder CMakeLists.txt for header installation - Fix gtest.cmake to use include_guard to prevent multiple inclusions - Update root CMakeLists.txt to include full builder directory instead of just tests * Scope C++20 settingto the test code Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Remove redundant GTest::gtest linkage Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Introduce basic types, and convolution algorithm concepts and limits. * Add convolution signature concepts. * Add convolution factory. * Finalize conv factory implementation for fwd convolutions. * Add type definitions for testing. * Add placeholder test. * Add convolution builder definition. * Fully functional fwd conv builder. * Test improvements. * Clean-up include headers. * Enable the limit checks for the convolution algorithm parameters. * Remove dead code. * clang formatting. * Add more tests and missing conv specialization argument. * clang formatting. * Add explicit handling of the tensor layouts. * Add complete 2D/3D layout support to CK Builder - Add missing 2D layouts: GNHWC_GKYXC_GNHWK, NGCHW_GKCYX_NGKHW - Add missing 3D layout: GNDHWC_GKZYXC_GNDHWK - Add 1D layouts (NWGC, NGCW, GNWC, NGCW_GKCX) for future support - Add 3 tests for new 2D/3D layouts - All tests pass (5/5) * Add tests for remaining 2D/3D layouts - Add test for 2D NGCHW_GKYXC_NGKHW (channels-first) with Filter1x1Stride1Pad0 - Add test for 3D NDHWGC_GKZYXC_NDHWGK (channels-last) - All 7 tests pass (complete coverage for all 2D/3D forward layouts) * Change enum converters to consteval. * 7 tests with pipeline and specialization| Test # | Dim | Type | Layout | Pipeline | Specialization | |--------|-----|------|----------------------|----------|-------------------------| | 1 | 2D | BF16 | NHWGC_GKYXC_NHWGK | V1 | DEFAULT | | 2 | 2D | FP16 | GNHWC_GKYXC_GNHWK | V3 | FILTER_1X1_PAD0 | | 3 | 2D | FP32 | NGCHW_GKCYX_NGKHW | V4 | FILTER_1X1_STRIDE1_PAD0 | | 4 | 2D | BF16 | NHWGC_GKYXC_NHWGK | V5 | FILTER_3x3 | | 5 | 3D | FP32 | NGCDHW_GKCZYX_NGKDHW | V1 | FILTER_1X1_PAD0 | | 6 | 3D | BF16 | GNDHWC_GKZYXC_GNDHWK | V3 | DEFAULT | | 7 | 3D | FP16 | NDHWGC_GKZYXC_NDHWGK | V4 | FILTER_1X1_PAD0 | * Add missing convolution layouts and provide better compile-time error in instance traits. * Fix clang formatting. * Changed I8 -> S8. * Fix signature. * Rename concepts and corresponding members. * Rename LDS related parameters. * Remove ODD_C specialization. Add V2 pipeline. * Add missing types. * Add elementwise operation to the conv signature. * Improve compile-time error message for unsupported elementwise ops. * Separate different fwd conv builder tests into separate compilation units. * Fix layout to string and add name to old CK PassThrough elementwise op. * Enable both CK and CK Tile tensor layouts in instance traits. * Fix clang-format. --------- Co-authored-by: John Shumway <jshumway@amd.com> Co-authored-by: John Shumway <john.shumwayjr@gmail.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: JH-Leon-KIM-AMD <jeonghyun.kim@amd.com>
This commit is contained in:
@@ -23,9 +23,18 @@ This project is a prototype for a more general builder pattern for all of compos
|
||||
|
||||
To enable the experimental builder, configure your build with:
|
||||
|
||||
```sh
|
||||
cmake -DCK_EXPERIMENTAL_BUILDER=ON -DCMAKE_CXX_STANDARD=20 ...
|
||||
```bash
|
||||
cmake \
|
||||
-D CMAKE_PREFIX_PATH=/opt/rocm \
|
||||
-D CMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
|
||||
-D CMAKE_BUILD_TYPE=Release \
|
||||
-D GPU_TARGETS="gfx942;gfx950" \
|
||||
-D CK_EXPERIMENTAL_BUILDER=ON \
|
||||
-D CMAKE_CXX_STANDARD=20 \
|
||||
-G Ninja \
|
||||
..
|
||||
```
|
||||
|
||||
## Building and testing
|
||||
|
||||
During development, build and test from the CK build directory with
|
||||
|
||||
143
experimental/builder/include/ck_tile/builder/builder_utils.hpp
Normal file
143
experimental/builder/include/ck_tile/builder/builder_utils.hpp
Normal file
@@ -0,0 +1,143 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/sequence.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// Convert a static array to a sequence
|
||||
// Usage example:
|
||||
// static constexpr std::vector arr {1, 2, 3};
|
||||
// using seq = to_sequence_v<arr>; // seq is ck::Sequence<1, 2, 3>
|
||||
template <typename T, const T& Arr>
|
||||
struct to_sequence_t
|
||||
{
|
||||
private:
|
||||
template <std::size_t... Is>
|
||||
static auto get_sequence_type(std::index_sequence<Is...>) -> ck::Sequence<Arr[Is]...>;
|
||||
|
||||
// Helper method to handler the unusual .Size() method name in ck::Array.
|
||||
static constexpr auto get_size(const auto& arr)
|
||||
{
|
||||
if constexpr(requires { arr.size(); })
|
||||
{
|
||||
return arr.size();
|
||||
}
|
||||
else
|
||||
{
|
||||
return arr.Size();
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
using value = decltype(get_sequence_type(std::make_index_sequence<get_size(Arr)>{}));
|
||||
};
|
||||
|
||||
template <auto& Arr>
|
||||
using to_sequence_v = typename to_sequence_t<std::remove_cvref_t<decltype(Arr)>, Arr>::value;
|
||||
|
||||
// Wrapper function to make constexpr strings a structural type for NTTP.
|
||||
template <size_t N>
|
||||
struct StringLiteral
|
||||
{
|
||||
char data[N];
|
||||
constexpr StringLiteral(const char (&str)[N])
|
||||
{
|
||||
for(size_t i = 0; i < N; ++i)
|
||||
data[i] = str[i];
|
||||
}
|
||||
|
||||
constexpr bool operator==(const StringLiteral<N>& other) const
|
||||
{
|
||||
for(size_t i = 0; i < N; ++i)
|
||||
{
|
||||
if(data[i] != other.data[i])
|
||||
{
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
// This is a C++17 deduction guide. It allows the compiler to automatically
|
||||
// deduce the template argument `N` for `StringLiteral` from a string literal
|
||||
// constructor argument. For example, you can write `StringLiteral s{"foo"};`
|
||||
// instead of `StringLiteral<4> s{"foo"};`.
|
||||
template <size_t N>
|
||||
StringLiteral(const char (&)[N]) -> StringLiteral<N>;
|
||||
|
||||
// Helper to provide a readable error for unsupported enum values.
|
||||
// The compiler will print the name of this struct in the error message, so
|
||||
// the name of the enum value will appear instead of just its integer value.
|
||||
template <auto T>
|
||||
struct UnsupportedEnumValue
|
||||
{
|
||||
};
|
||||
|
||||
// Helper functions to convert enums to strings
|
||||
constexpr std::string_view ConvDirectionToString(ConvDirection dir)
|
||||
{
|
||||
switch(dir)
|
||||
{
|
||||
case ConvDirection::FORWARD: return "Forward";
|
||||
case ConvDirection::BACKWARD_DATA: return "Backward Data";
|
||||
case ConvDirection::BACKWARD_WEIGHT: return "Backward Weight";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
constexpr std::string_view DataTypeToString(DataType dt)
|
||||
{
|
||||
switch(dt)
|
||||
{
|
||||
case DataType::FP16: return "FP16";
|
||||
case DataType::FP32: return "FP32";
|
||||
case DataType::BF16: return "BF16";
|
||||
case DataType::FP8: return "FP8";
|
||||
case DataType::I8: return "I8";
|
||||
case DataType::U8: return "U8";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
constexpr std::string_view LayoutToString(GroupConvLayout1D layout)
|
||||
{
|
||||
switch(layout)
|
||||
{
|
||||
case GroupConvLayout1D::GNWC_GKXC_GNWK: return "GNWC_GKXC_GNWK";
|
||||
case GroupConvLayout1D::NWGC_GKXC_NWGK: return "NWGC_GKXC_NWGK";
|
||||
case GroupConvLayout1D::NGCW_GKXC_NGKW: return "NGCW_GKXC_NGKW";
|
||||
case GroupConvLayout1D::NGCW_GKCX_NGKW: return "NGCW_GKCX_NGKW";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
constexpr std::string_view LayoutToString(GroupConvLayout2D layout)
|
||||
{
|
||||
switch(layout)
|
||||
{
|
||||
case GroupConvLayout2D::GNHWC_GKYXC_GNHWK: return "GNHWC_GKYXC_GNHWK";
|
||||
case GroupConvLayout2D::NHWGC_GKYXC_NHWGK: return "NHWGC_GKYXC_NHWGK";
|
||||
case GroupConvLayout2D::NGCHW_GKYXC_NGKHW: return "NGCHW_GKYXC_NGKHW";
|
||||
case GroupConvLayout2D::NGCHW_GKCYX_NGKHW: return "NGCHW_GKCYX_NGKHW";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
constexpr std::string_view LayoutToString(GroupConvLayout3D layout)
|
||||
{
|
||||
switch(layout)
|
||||
{
|
||||
case GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK: return "GNDHWC_GKZYXC_GNDHWK";
|
||||
case GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK: return "NDHWGC_GKZYXC_NDHWGK";
|
||||
case GroupConvLayout3D::NGCDHW_GKZYXC_NGKDHW: return "NGCDHW_GKZYXC_NGKDHW";
|
||||
case GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW: return "NGCDHW_GKCZYX_NGKDHW";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
@@ -0,0 +1,141 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
#include <concepts>
|
||||
#include <array>
|
||||
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
/********************************************************************/
|
||||
/* Descriptors for individual elements of the algorithm description */
|
||||
/********************************************************************/
|
||||
|
||||
// Concept for thread block dimensions for a GEMM problem.
|
||||
template <typename T>
|
||||
concept ThreadBlockDescriptor = requires(T t) {
|
||||
{ t.block_size } -> std::convertible_to<size_t>;
|
||||
{ t.tile_size.m } -> std::convertible_to<size_t>;
|
||||
{ t.tile_size.n } -> std::convertible_to<size_t>;
|
||||
{ t.tile_size.k } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
// Concept for parameters that describe a gridwise GEMM problem.
|
||||
template <typename T>
|
||||
concept GridwiseGemmDescriptor = requires(T t) {
|
||||
{ t.ak1 } -> std::convertible_to<size_t>;
|
||||
{ t.bk1 } -> std::convertible_to<size_t>;
|
||||
{ t.m_per_xdl } -> std::convertible_to<size_t>;
|
||||
{ t.n_per_xdl } -> std::convertible_to<size_t>;
|
||||
{ t.m_xdl_per_wave } -> std::convertible_to<size_t>;
|
||||
{ t.n_xdl_per_wave } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
// Concept for vectorized data transfer for convolution input tensors.
|
||||
template <typename T>
|
||||
concept BlockTransferDescriptor = requires(T t) {
|
||||
{ t.k0 } -> std::convertible_to<size_t>;
|
||||
{ t.m_n } -> std::convertible_to<size_t>;
|
||||
{ t.k1 } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
// Concept for thread cluster dimensions for GEMM output tensor.
|
||||
template <typename T>
|
||||
concept ThreadClusterDescriptor = requires(T t) {
|
||||
{ t.m_block } -> std::convertible_to<size_t>;
|
||||
{ t.m_wave_per_xdl } -> std::convertible_to<size_t>;
|
||||
{ t.n_block } -> std::convertible_to<size_t>;
|
||||
{ t.n_wave_per_xdl } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
// Concept for the LDS transfer for the convolution input tensors.
|
||||
template <typename T>
|
||||
concept LdsTransferDescriptor = requires(T t) {
|
||||
{ t.src_vector_dim } -> std::convertible_to<size_t>;
|
||||
{ t.src_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ t.lds_dst_scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
{ t.is_direct_load } -> std::convertible_to<bool>;
|
||||
{ t.lds_padding } -> std::convertible_to<bool>;
|
||||
};
|
||||
|
||||
// Concept for the convolution output tensor epilogue (copy from registers to global memory via
|
||||
// LDS).
|
||||
template <typename T>
|
||||
concept EpilogueDescriptor = requires(T t) {
|
||||
{ t.m_xdl_per_wave_per_shuffle } -> std::convertible_to<size_t>;
|
||||
{ t.n_xdl_per_wave_per_shuffle } -> std::convertible_to<size_t>;
|
||||
{ t.scalar_per_vector } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
// Concept for the thread cluster access order
|
||||
template <typename T>
|
||||
concept AccessOrderDescriptor = requires(T t) {
|
||||
{ t.order } -> std::convertible_to<std::array<size_t, 3>>;
|
||||
};
|
||||
|
||||
// No requirements yet for a ConvAlogorithm concept.
|
||||
template <typename T>
|
||||
concept ConvAlgorithmDescriptor = std::is_class_v<T>;
|
||||
|
||||
/******************************************** */
|
||||
/* Requirements for the algorithm description */
|
||||
/******************************************** */
|
||||
|
||||
// Concept to check if struct specifies thread block info.
|
||||
template <typename T>
|
||||
concept SpecifiesThreadBlock = requires {
|
||||
{ T::thread_block } -> ThreadBlockDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies gridwise GEMM info.
|
||||
template <typename T>
|
||||
concept SpecifiesGridwiseGemm = requires {
|
||||
{ T::gridwise_gemm } -> GridwiseGemmDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies convolution input and output block transfer info.
|
||||
template <typename T>
|
||||
concept SpecifiesBlockTransfer = requires(T t) {
|
||||
{ T::block_transfer.block_transfer_a } -> BlockTransferDescriptor;
|
||||
{ T::block_transfer.block_transfer_b } -> BlockTransferDescriptor;
|
||||
{ T::block_transfer.thread_cluster_dims_c } -> ThreadClusterDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies LDS transfer info for tensors A, B, and C.
|
||||
template <typename T>
|
||||
concept SpecifiesLdsTransfer = requires(T t) {
|
||||
{ T::block_transfer.lds_transfer_a } -> LdsTransferDescriptor;
|
||||
{ T::block_transfer.lds_transfer_b } -> LdsTransferDescriptor;
|
||||
{ T::block_transfer.epilogue_c } -> EpilogueDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies thread cluster access order info.
|
||||
template <typename T>
|
||||
concept SpecifiesThreadClusterAccessOrder = requires(T t) {
|
||||
{ T::block_transfer.block_transfer_access_order_a } -> AccessOrderDescriptor;
|
||||
{ T::block_transfer.block_transfer_access_order_b } -> AccessOrderDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies source access order info.
|
||||
template <typename T>
|
||||
concept SpecifiesSourceAccessOrder = requires(T t) {
|
||||
{ T::block_transfer.src_access_order_a } -> AccessOrderDescriptor;
|
||||
{ T::block_transfer.src_access_order_b } -> AccessOrderDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if struct specifies block_gemm_pipeline_version.
|
||||
template <typename T>
|
||||
concept SpecifiesGemmPipelineVersion = requires {
|
||||
{ T::pipeline_version } -> std::convertible_to<BlockGemmPipelineVersion>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesFwdConcSpecialization = requires {
|
||||
{ T::fwd_specialization } -> std::convertible_to<ConvFwdSpecialization>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
#include <concepts>
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// Limits for input vector transfer.
|
||||
template <auto Value>
|
||||
concept InputVectorTransferLimits = requires {
|
||||
requires Value.src_vector_dim > 0 && Value.src_scalar_per_vector > 0 &&
|
||||
Value.lds_dst_scalar_per_vector > 0;
|
||||
};
|
||||
|
||||
// Limits for output vector transfer.
|
||||
template <auto Value>
|
||||
concept OutputVectorTransferLimits = requires {
|
||||
requires Value.scalar_per_vector > 0 && Value.m_xdl_per_wave_per_shuffle > 0 &&
|
||||
Value.n_xdl_per_wave_per_shuffle > 0;
|
||||
};
|
||||
|
||||
// Limits for access order. Must be a permutation of {0, 1, 2}.
|
||||
template <auto Value>
|
||||
concept AccessOrderLimits = requires {
|
||||
requires((Value[0] != Value[1]) && (Value[0] != Value[2]) && (Value[1] != Value[2]) &&
|
||||
(Value[0] >= 0 && Value[0] < 3) && (Value[1] >= 0 && Value[1] < 3) &&
|
||||
(Value[2] >= 0 && Value[2] < 3));
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
@@ -0,0 +1,38 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/builder/conv_factory.hpp"
|
||||
#include "ck_tile/builder/versions.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
/**
|
||||
* @brief Top-level builder for creating convolution kernel instances.
|
||||
*
|
||||
* This struct serves as the main entry point for generating a convolution kernel.
|
||||
* It uses a factory pattern based on the provided signature, algorithm, and version
|
||||
* to construct the appropriate kernel instance.
|
||||
*
|
||||
* @tparam SIGNATURE The convolution signature, which describes the mathematical functionality of
|
||||
* the algorithm (e.g., data types, layouts, direction).
|
||||
* @tparam ALGORITHM The specific convolution algorithm to be used for the implementation.
|
||||
* @tparam VERSION The version of the builder implementation.
|
||||
*/
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION = LATEST_API_VERSION>
|
||||
requires SupportedVersion<VERSION> && ValidConvSignature<SIGNATURE>
|
||||
struct ConvBuilder
|
||||
{
|
||||
static constexpr auto kVersion = VERSION;
|
||||
using Factory = ConvFactory<SIGNATURE, ALGORITHM, VERSION>;
|
||||
// Output: The kernel class.
|
||||
using Instance = Factory::Instance;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
539
experimental/builder/include/ck_tile/builder/conv_factory.hpp
Normal file
539
experimental/builder/include/ck_tile/builder/conv_factory.hpp
Normal file
@@ -0,0 +1,539 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// A factory for instantiating CK convolution kernels.
|
||||
//
|
||||
// This file translates a semantic description of a convolution operation
|
||||
// (`ConvSignatureDescriptor` and `ConvAlgorithmDescriptor`) into specific,
|
||||
// low-level template arguments required by the underlying CK device-level
|
||||
// kernel implementations. This abstraction enables more complex build
|
||||
// time logic and simplifies the kernel specification.
|
||||
//
|
||||
// Key Components:
|
||||
//
|
||||
// Template Metaprogram:
|
||||
// - ConvFactory: The main factory, with specializations for different
|
||||
// convolution directions (currently only forward).
|
||||
//
|
||||
// Template Metaprogram Helpers:
|
||||
// - ConvTensorLayouts: Maps layout enums to CK layout types for different
|
||||
// spatial dimensions (2D/3D) and directions.
|
||||
// - ConvTensorTypes: Maps data type enums (FP16, BF16, FP32) to C++ types used by CK.
|
||||
// - ConvPassThroughOps: Hard-coded pass-through element-wise operations.
|
||||
// - ConvSpec: Encapsulates convolution and GEMM specialization enums.
|
||||
//
|
||||
// `constexpr` Helper Functions:
|
||||
// - SetThreadBlockInfo: Determines thread block dimensions and tile sizes.
|
||||
// - SetConvTuningInfo: Sets XDL and AK1/BK1 tuning parameters.
|
||||
// - SetFwdConvABlockTransfer: Configures A tensor block transfer parameters.
|
||||
// - SetFwdConvBBlockTransfer: Configures B tensor block transfer parameters.
|
||||
// - SetCBlockTransfer: Configures C tensor block transfer parameters.
|
||||
// - SetBlockGemmPipelineVersion: Maps pipeline version enum to CK types.
|
||||
//
|
||||
// The primary entry point is the `ConvFactory` struct, which is currently
|
||||
// specialized for forward convolutions and produces instances of
|
||||
// DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp"
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
#include "ck_tile/builder/conv_algorithm_limits.hpp"
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
#include "ck_tile/builder/versions.hpp"
|
||||
|
||||
namespace ck_tile::builder::factory_internal {
|
||||
|
||||
// Type mappings from the builder FwdGroupConvLayout enum classes to the CK tensor data types.
|
||||
template <auto LayoutValue, size_t SPATIAL_DIM, ConvDirection DIR>
|
||||
requires(ConvSpatialDim<SPATIAL_DIM> && ValidConvLayoutForSpatialDim<LayoutValue, SPATIAL_DIM>)
|
||||
struct ConvTensorLayouts
|
||||
{
|
||||
// This will trigger if a specialization for the given layout is not found.
|
||||
// We should always catch this in an earlier validation check.
|
||||
using Layout = decltype(LayoutValue);
|
||||
static_assert(sizeof(Layout) == 0,
|
||||
"Internal error. Unsupported layout for convolution factory.");
|
||||
};
|
||||
|
||||
// 1D Forward Convolution Layout Specializations
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout1D::NWGC_GKXC_NWGK, 1, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NWGC;
|
||||
using BLayout = ck::tensor_layout::convolution::GKXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NWGK;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout1D::NGCW_GKXC_NGKW, 1, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NGCW;
|
||||
using BLayout = ck::tensor_layout::convolution::GKXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NGKW;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout1D::GNWC_GKXC_GNWK, 1, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::GNWC;
|
||||
using BLayout = ck::tensor_layout::convolution::GKXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::GNWK;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout1D::NGCW_GKCX_NGKW, 1, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NGCW;
|
||||
using BLayout = ck::tensor_layout::convolution::GKCX;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NGKW;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout2D::NGCHW_GKYXC_NGKHW, 2, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NGCHW;
|
||||
using BLayout = ck::tensor_layout::convolution::GKYXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NGKHW;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout2D::NHWGC_GKYXC_NHWGK, 2, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NHWGC;
|
||||
using BLayout = ck::tensor_layout::convolution::GKYXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NHWGK;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout2D::GNHWC_GKYXC_GNHWK, 2, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::GNHWC;
|
||||
using BLayout = ck::tensor_layout::convolution::GKYXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::GNHWK;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout2D::NGCHW_GKCYX_NGKHW, 2, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NGCHW;
|
||||
using BLayout = ck::tensor_layout::convolution::GKCYX;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NGKHW;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW, 3, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NGCDHW;
|
||||
using BLayout = ck::tensor_layout::convolution::GKCZYX;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NGKDHW;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK, 3, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::NDHWGC;
|
||||
using BLayout = ck::tensor_layout::convolution::GKZYXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::NDHWGK;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorLayouts<GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK, 3, ConvDirection::FORWARD>
|
||||
{
|
||||
using ALayout = ck::tensor_layout::convolution::GNDHWC;
|
||||
using BLayout = ck::tensor_layout::convolution::GKZYXC;
|
||||
using DsLayout = ck::Tuple<>;
|
||||
using ELayout = ck::tensor_layout::convolution::GNDHWK;
|
||||
};
|
||||
|
||||
// Type mappings from builder convolution data type to CK tensor types.
|
||||
template <DataType T>
|
||||
struct ConvTensorTypes
|
||||
{
|
||||
// This will trigger if a specialization for the given DataType is not found.
|
||||
// We should always catch this in an earlier validation check.
|
||||
static_assert(sizeof(UnsupportedEnumValue<T>) == 0,
|
||||
"Internal error. Unsupported data type for convolution factory.");
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorTypes<DataType::FP16>
|
||||
{
|
||||
using ADataType = ck::half_t;
|
||||
using BDataType = ck::half_t;
|
||||
using CShuffleDataType = ck::half_t;
|
||||
using DsDataTypes = ck::Tuple<>;
|
||||
using AccDataType = float;
|
||||
using EDataType = ck::half_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorTypes<DataType::BF16>
|
||||
{
|
||||
using ADataType = ck::bhalf_t;
|
||||
using BDataType = ck::bhalf_t;
|
||||
using CShuffleDataType = ck::bhalf_t;
|
||||
using DsDataTypes = ck::Tuple<>;
|
||||
using AccDataType = float;
|
||||
using EDataType = ck::bhalf_t;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvTensorTypes<DataType::FP32>
|
||||
{
|
||||
using ADataType = float;
|
||||
using BDataType = float;
|
||||
using CShuffleDataType = float;
|
||||
using DsDataTypes = ck::Tuple<>;
|
||||
using AccDataType = float;
|
||||
using EDataType = float;
|
||||
};
|
||||
|
||||
template <ElementwiseOperation T>
|
||||
struct ElementwiseOps
|
||||
{
|
||||
// This will trigger if a specialization for the given DataType is not found.
|
||||
// We should always catch this in an earlier validation check.
|
||||
static_assert(sizeof(UnsupportedEnumValue<T>) == 0,
|
||||
"Internal error. Unsupported elementwise operation for convolution factory.");
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ElementwiseOps<ElementwiseOperation::PASS_THROUGH>
|
||||
{
|
||||
using AElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using BElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
using CDEElementwiseOp = ck::tensor_operation::element_wise::PassThrough;
|
||||
};
|
||||
|
||||
// The algorithm specializations for the convolution and GEMM.
|
||||
template <typename CONV_ENUM>
|
||||
requires(
|
||||
std::is_same_v<CONV_ENUM, ck::tensor_operation::device::ConvolutionForwardSpecialization>)
|
||||
struct ConvSpec
|
||||
{
|
||||
CONV_ENUM conv_spec;
|
||||
ck::tensor_operation::device::GemmSpecialization gemm_spec;
|
||||
};
|
||||
|
||||
// Deduction guide for ConvSpec to simplify brace initialization.
|
||||
template <typename CONV_ENUM, typename GEMM_ENUM>
|
||||
ConvSpec(CONV_ENUM, GEMM_ENUM) -> ConvSpec<CONV_ENUM>;
|
||||
|
||||
// Block info for a convolution.
|
||||
struct MNK
|
||||
{
|
||||
size_t m{};
|
||||
size_t n{};
|
||||
size_t k{};
|
||||
};
|
||||
struct ConvBlock
|
||||
{
|
||||
size_t block_size = 0;
|
||||
MNK per_block = {};
|
||||
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr ConvBlock SetThreadBlockInfo()
|
||||
{
|
||||
constexpr auto& TB = ALGORITHM.thread_block;
|
||||
return ConvBlock{.block_size = TB.block_size,
|
||||
.per_block = {.m = TB.tile_size.m, .n = TB.tile_size.n, .k = TB.tile_size.k}};
|
||||
}
|
||||
|
||||
// Convolution tuning parameters.
|
||||
struct GridwiseGemm
|
||||
{
|
||||
size_t ak1 = 0;
|
||||
size_t bk1 = 0;
|
||||
size_t m_per_xdl = 0;
|
||||
size_t n_per_xdl = 0;
|
||||
size_t m_xdl_per_wave = 0;
|
||||
size_t n_xdl_per_wave = 0;
|
||||
};
|
||||
|
||||
template <ConvSignatureDescriptor auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr GridwiseGemm SetGridwiseGemmInfo()
|
||||
{
|
||||
constexpr auto& TP = ALGORITHM.gridwise_gemm;
|
||||
return GridwiseGemm{
|
||||
.ak1 = TP.ak1,
|
||||
.bk1 = TP.bk1,
|
||||
.m_per_xdl = TP.m_per_xdl,
|
||||
.n_per_xdl = TP.n_per_xdl,
|
||||
.m_xdl_per_wave = TP.m_xdl_per_wave,
|
||||
.n_xdl_per_wave = TP.n_xdl_per_wave,
|
||||
};
|
||||
}
|
||||
|
||||
// Block transfer parameters for A or B tensor.
|
||||
struct BlockTransfer
|
||||
{
|
||||
ck::Array<size_t, 3> thread_cluster_dims = {0, 0, 0}; // k0, m, k1
|
||||
ck::Array<size_t, 3> thread_cluster_order = {0, 0, 0};
|
||||
ck::Array<size_t, 3> src_access_order = {0, 0, 0};
|
||||
size_t src_vector_dim = 0;
|
||||
size_t src_scalar_per_vector = 0;
|
||||
size_t lds_dst_scalar_per_vector = 0;
|
||||
bool is_direct_load = false;
|
||||
bool lds_padding = false;
|
||||
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr BlockTransfer SetFwdConvABlockTransfer()
|
||||
{
|
||||
constexpr auto& TCL = ALGORITHM.block_transfer.block_transfer_a;
|
||||
constexpr auto& TCO = ALGORITHM.block_transfer.block_transfer_access_order_a;
|
||||
constexpr auto& SAO = ALGORITHM.block_transfer.src_access_order_a;
|
||||
constexpr auto& LDS = ALGORITHM.block_transfer.lds_transfer_a;
|
||||
|
||||
BlockTransfer block_transfer{.thread_cluster_dims = {TCL.k0, TCL.m_n, TCL.k1},
|
||||
.thread_cluster_order = {TCO.order[0], TCO.order[1], TCO.order[2]},
|
||||
.src_access_order = {SAO.order[0], SAO.order[1], SAO.order[2]},
|
||||
.src_vector_dim = LDS.src_vector_dim,
|
||||
.src_scalar_per_vector = LDS.src_scalar_per_vector,
|
||||
.lds_dst_scalar_per_vector = LDS.lds_dst_scalar_per_vector,
|
||||
.is_direct_load = LDS.is_direct_load,
|
||||
.lds_padding = LDS.lds_padding};
|
||||
return block_transfer;
|
||||
}
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr BlockTransfer SetFwdConvBBlockTransfer()
|
||||
{
|
||||
constexpr auto& TCL = ALGORITHM.block_transfer.block_transfer_b;
|
||||
constexpr auto& TCO = ALGORITHM.block_transfer.block_transfer_access_order_b;
|
||||
constexpr auto& SAO = ALGORITHM.block_transfer.src_access_order_b;
|
||||
constexpr auto& LDS = ALGORITHM.block_transfer.lds_transfer_b;
|
||||
|
||||
BlockTransfer block_transfer{.thread_cluster_dims = {TCL.k0, TCL.m_n, TCL.k1},
|
||||
.thread_cluster_order = {TCO.order[0], TCO.order[1], TCO.order[2]},
|
||||
.src_access_order = {SAO.order[0], SAO.order[1], SAO.order[2]},
|
||||
.src_vector_dim = LDS.src_vector_dim,
|
||||
.src_scalar_per_vector = LDS.src_scalar_per_vector,
|
||||
.lds_dst_scalar_per_vector = LDS.lds_dst_scalar_per_vector,
|
||||
.is_direct_load = LDS.is_direct_load,
|
||||
.lds_padding = LDS.lds_padding};
|
||||
return block_transfer;
|
||||
}
|
||||
|
||||
// Block transfer parameters for C tensor.
|
||||
struct CBlockTransfer
|
||||
{
|
||||
size_t m_xdl_per_wave_per_shuffle = 0;
|
||||
size_t n_xdl_per_wave_per_shuffle = 0;
|
||||
ck::Array<size_t, 4> thread_cluster_dims = {0, 0, 0, 0};
|
||||
size_t scalar_per_vector = 0;
|
||||
};
|
||||
|
||||
template <ConvSignatureDescriptor auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr CBlockTransfer SetCBlockTransfer()
|
||||
{
|
||||
constexpr auto& TCL = ALGORITHM.block_transfer.thread_cluster_dims_c;
|
||||
constexpr auto& EPC = ALGORITHM.block_transfer.epilogue_c;
|
||||
CBlockTransfer block_transfer{.m_xdl_per_wave_per_shuffle = EPC.m_xdl_per_wave_per_shuffle,
|
||||
.n_xdl_per_wave_per_shuffle = EPC.n_xdl_per_wave_per_shuffle,
|
||||
.thread_cluster_dims =
|
||||
{
|
||||
TCL.m_block,
|
||||
TCL.m_wave_per_xdl,
|
||||
TCL.n_block,
|
||||
TCL.n_wave_per_xdl,
|
||||
},
|
||||
.scalar_per_vector = EPC.scalar_per_vector};
|
||||
return block_transfer;
|
||||
}
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
|
||||
{
|
||||
constexpr auto version = ALGORITHM.pipeline_version;
|
||||
|
||||
if constexpr(version == BlockGemmPipelineVersion::V1)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v1;
|
||||
}
|
||||
else if constexpr(version == BlockGemmPipelineVersion::V3)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v3;
|
||||
}
|
||||
else if constexpr(version == BlockGemmPipelineVersion::V4)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v4;
|
||||
}
|
||||
else if constexpr(version == BlockGemmPipelineVersion::V5)
|
||||
{
|
||||
return ck::BlockGemmPipelineVersion::v5;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown BlockGemmPipelineVersion");
|
||||
}
|
||||
}
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval ck::tensor_operation::device::ConvolutionForwardSpecialization SetFwdConvSpecialization()
|
||||
{
|
||||
constexpr auto specialization = ALGORITHM.fwd_specialization;
|
||||
|
||||
if constexpr(specialization == ConvFwdSpecialization::DEFAULT)
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
}
|
||||
else if constexpr(specialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
|
||||
}
|
||||
else if constexpr(specialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
|
||||
}
|
||||
else if constexpr(specialization == ConvFwdSpecialization::FILTER_3x3)
|
||||
{
|
||||
return ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter3x3;
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert(false, "Unknown ConvFwdSpecialization");
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::factory_internal
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// Primary template for the convolution factory.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
auto VERSION>
|
||||
struct ConvFactory;
|
||||
|
||||
// Factory specialization for an instance of a grouped forward convolution kernel.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
StringLiteral VERSION>
|
||||
requires ConvDirectionIsForward<SIGNATURE>
|
||||
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
{
|
||||
static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts =
|
||||
factory_internal::ConvTensorLayouts<SIGNATURE.layout, SPATIAL_DIM, ConvDirection::FORWARD>;
|
||||
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = factory_internal::ElementwiseOps<SIGNATURE.elementwise_operation>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
// Check preconditions for the algorithm description.
|
||||
static_assert(SPATIAL_DIM == 2 || SPATIAL_DIM == 3,
|
||||
"Only 2D and 3D convolutions are supported in this factory.");
|
||||
static_assert(SpecifiesThreadBlock<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread block info.");
|
||||
static_assert(SpecifiesGridwiseGemm<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify gridwise GEMM info.");
|
||||
static_assert(SpecifiesBlockTransfer<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify block transfer info.");
|
||||
static_assert(SpecifiesLdsTransfer<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify LDS transfer info.");
|
||||
static_assert(
|
||||
SpecifiesThreadClusterAccessOrder<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify thread cluster access order info.");
|
||||
static_assert(SpecifiesSourceAccessOrder<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify source access order info.");
|
||||
static_assert(SpecifiesGemmPipelineVersion<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify block gemm pipeline version.");
|
||||
static_assert(SpecifiesFwdConcSpecialization<AlgorithmType>,
|
||||
"The convolution algorithm descriptor must specify forward convolution "
|
||||
"specialization.");
|
||||
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION =
|
||||
factory_internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
static constexpr factory_internal::ConvSpec SPECIALIZATION{
|
||||
.conv_spec = FWD_CONV_SPECIALIZATION,
|
||||
.gemm_spec = ck::tensor_operation::device::GemmSpecialization::MNKPadding,
|
||||
};
|
||||
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM =
|
||||
factory_internal::SetGridwiseGemmInfo<SIGNATURE, ALGORITHM>();
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
factory_internal::SetFwdConvABlockTransfer<ALGORITHM>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
factory_internal::SetFwdConvBBlockTransfer<ALGORITHM>();
|
||||
static constexpr auto C_BLOCK_TRANSFER =
|
||||
factory_internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto PIPELINE_VERSION =
|
||||
factory_internal::SetBlockGemmPipelineVersion<ALGORITHM>();
|
||||
|
||||
// Check limits for the algorithm parameters.
|
||||
// TODO: Add more limits checks as needed.
|
||||
static_assert(InputVectorTransferLimits<A_BLOCK_TRANSFER>);
|
||||
static_assert(InputVectorTransferLimits<B_BLOCK_TRANSFER>);
|
||||
static_assert(OutputVectorTransferLimits<C_BLOCK_TRANSFER>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.thread_cluster_order>);
|
||||
static_assert(AccessOrderLimits<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< //
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::CShuffleDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataType,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
SPECIALIZATION.conv_spec,
|
||||
SPECIALIZATION.gemm_spec,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.ak1,
|
||||
GRIDWISE_GEMM.bk1,
|
||||
GRIDWISE_GEMM.m_per_xdl,
|
||||
GRIDWISE_GEMM.n_per_xdl,
|
||||
GRIDWISE_GEMM.m_xdl_per_wave,
|
||||
GRIDWISE_GEMM.n_xdl_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
A_BLOCK_TRANSFER.lds_padding,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_dst_scalar_per_vector,
|
||||
B_BLOCK_TRANSFER.lds_padding,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
to_sequence_v<C_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
C_BLOCK_TRANSFER.scalar_per_vector,
|
||||
PIPELINE_SCHEDULER,
|
||||
PIPELINE_VERSION>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
@@ -0,0 +1,74 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
// This file defines the compile-time "signature" for grouped convolution operations.
|
||||
// A signature is a collection of properties that fully describe a convolution kernel's
|
||||
// mathematical characteristics. It uses C++20 concepts and enums to specify these
|
||||
// properties, enabling compile-time validation and specialization.
|
||||
//
|
||||
// The core components of a signature are:
|
||||
// - Spatial dimensionality (1D, 2D, 3D)
|
||||
// - Operational direction (Forward, Backward Data, Backward Weight)
|
||||
// - Tensor memory layout (Channels First/Last)
|
||||
// - Data type (FP32, FP16, BF16)
|
||||
// - Fused element-wise operation (e.g., Bias, Clamp)
|
||||
//
|
||||
// The file also provides predicate concepts to query the properties of a given
|
||||
// signature at compile time.
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
#include <type_traits>
|
||||
|
||||
#include "ck_tile/builder/types.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// Constrains convolution to 1D, 2D, or 3D spatial dimensions.
|
||||
template <auto N>
|
||||
concept ConvSpatialDim = std::is_integral_v<decltype(N)> && (N == 1 || N == 2 || N == 3);
|
||||
|
||||
// Constraints for forward convolution layouts.
|
||||
template <auto LayoutValue, size_t SpatialDim>
|
||||
concept ValidConvLayoutForSpatialDim =
|
||||
(SpatialDim == 1 && std::same_as<decltype(LayoutValue), GroupConvLayout1D>) ||
|
||||
(SpatialDim == 2 && std::same_as<decltype(LayoutValue), GroupConvLayout2D>) ||
|
||||
(SpatialDim == 3 && std::same_as<decltype(LayoutValue), GroupConvLayout3D>);
|
||||
|
||||
// Constrains convolution data types to common floating-point types.
|
||||
template <DataType T>
|
||||
concept ConvDataType = (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16) ||
|
||||
(T == DataType::FP8) || (T == DataType::I8) || (T == DataType::U8);
|
||||
|
||||
// Concept for a type that defines a convolution's operational signature.
|
||||
template <typename T>
|
||||
concept ConvSignatureDescriptor = requires(T t) {
|
||||
{ t.spatial_dim } -> std::convertible_to<unsigned int>;
|
||||
{ t.direction } -> std::convertible_to<ConvDirection>;
|
||||
requires std::convertible_to<decltype(t.layout), GroupConvLayout1D> ||
|
||||
std::convertible_to<decltype(t.layout), GroupConvLayout2D> ||
|
||||
std::convertible_to<decltype(t.layout), GroupConvLayout3D>;
|
||||
{ t.data_type } -> std::convertible_to<DataType>;
|
||||
{ t.elementwise_operation } -> std::convertible_to<ElementwiseOperation>;
|
||||
};
|
||||
|
||||
// Concept to validate a convolution signature's values.
|
||||
template <auto Sig>
|
||||
concept ValidConvSignature = requires {
|
||||
requires ConvSpatialDim<Sig.spatial_dim>;
|
||||
requires ConvDataType<Sig.data_type>;
|
||||
};
|
||||
|
||||
// Predicate for forward convolution.
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsForward = (Sig.direction == ConvDirection::FORWARD);
|
||||
|
||||
// Predicate for backward data convolution.
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_DATA);
|
||||
|
||||
// Predicate for backward weight convolution.
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT);
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
@@ -16,6 +16,7 @@
|
||||
#include <ck/utility/sequence.hpp>
|
||||
#include <ck/utility/blkgemmpipe_scheduler.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
|
||||
#include <ck_tile/ops/common/tensor_layout.hpp>
|
||||
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
|
||||
@@ -62,7 +63,9 @@ consteval std::string_view type_name()
|
||||
template <typename T>
|
||||
constexpr std::string_view layout_name()
|
||||
{
|
||||
if constexpr(std::is_base_of_v<ck_tile::tensor_layout::BaseTensorLayout, T> && requires {
|
||||
if constexpr((std::is_base_of_v<ck_tile::tensor_layout::BaseTensorLayout, T> ||
|
||||
std::is_base_of_v<ck::tensor_layout::BaseTensorLayout, T>) &&
|
||||
requires {
|
||||
{ T::name } -> std::convertible_to<std::string_view>;
|
||||
})
|
||||
return T::name;
|
||||
|
||||
90
experimental/builder/include/ck_tile/builder/types.hpp
Normal file
90
experimental/builder/include/ck_tile/builder/types.hpp
Normal file
@@ -0,0 +1,90 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
enum class DataType
|
||||
{
|
||||
FP32,
|
||||
FP16,
|
||||
BF16,
|
||||
FP8,
|
||||
I8,
|
||||
U8
|
||||
};
|
||||
|
||||
// Memory layouts for 1D convolution tensors.
|
||||
// G: Group, N: Batch, K: Output Channel, C: Input Channel, W: Width
|
||||
// Enum defines Input, Weight, and Output tensor layouts respectively.
|
||||
enum class GroupConvLayout1D
|
||||
{
|
||||
GNWC_GKXC_GNWK,
|
||||
NWGC_GKXC_NWGK,
|
||||
NGCW_GKXC_NGKW,
|
||||
NGCW_GKCX_NGKW
|
||||
};
|
||||
|
||||
// Memory layouts for 2D convolution tensors.
|
||||
// G: Group, N: Batch, K: Output Channel, C: Input Channel, Y: Height, X: Width, H: Height
|
||||
// Enum defines Input, Weight, and Output tensor layouts respectively.
|
||||
enum class GroupConvLayout2D
|
||||
{
|
||||
GNHWC_GKYXC_GNHWK,
|
||||
NHWGC_GKYXC_NHWGK,
|
||||
NGCHW_GKYXC_NGKHW,
|
||||
NGCHW_GKCYX_NGKHW
|
||||
};
|
||||
|
||||
// Memory layouts for 3D convolution tensors.
|
||||
// G: Group, N: Batch, K: Output Channel, C: Input Channel, Z: Depth, Y: Height, X: Width, D: Depth,
|
||||
// H: Height Enum defines Input, Weight, and Output tensor layouts respectively.
|
||||
enum class GroupConvLayout3D
|
||||
{
|
||||
GNDHWC_GKZYXC_GNDHWK,
|
||||
NDHWGC_GKZYXC_NDHWGK,
|
||||
NGCDHW_GKZYXC_NGKDHW,
|
||||
NGCDHW_GKCZYX_NGKDHW,
|
||||
};
|
||||
|
||||
// Direction of the convolution operation.
|
||||
enum class ConvDirection
|
||||
{
|
||||
FORWARD,
|
||||
BACKWARD_DATA,
|
||||
BACKWARD_WEIGHT
|
||||
};
|
||||
|
||||
// Fused element-wise operations.
|
||||
enum class ElementwiseOperation
|
||||
{
|
||||
BIAS,
|
||||
BIAS_CLAMP,
|
||||
BIAS_BNORM_CLAMP,
|
||||
BILINEAR,
|
||||
CLAMP,
|
||||
SCALE,
|
||||
PASS_THROUGH
|
||||
};
|
||||
|
||||
// Enums for the current block GEMM pipeline versions.
|
||||
enum class BlockGemmPipelineVersion
|
||||
{
|
||||
V1,
|
||||
V2,
|
||||
V3,
|
||||
V4,
|
||||
V5
|
||||
};
|
||||
|
||||
// Enums for the forward convolution specialization.
|
||||
enum class ConvFwdSpecialization
|
||||
{
|
||||
DEFAULT,
|
||||
FILTER_1X1_PAD0,
|
||||
FILTER_1X1_STRIDE1_PAD0,
|
||||
FILTER_3x3
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
18
experimental/builder/include/ck_tile/builder/versions.hpp
Normal file
18
experimental/builder/include/ck_tile/builder/versions.hpp
Normal file
@@ -0,0 +1,18 @@
|
||||
#pragma once
|
||||
|
||||
#include <concepts>
|
||||
#include <string_view>
|
||||
|
||||
#include "ck_tile/builder/builder_utils.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
static constexpr StringLiteral V0_0_0 = "0.0.0";
|
||||
static constexpr StringLiteral V0_1_0 = "0.1.0";
|
||||
|
||||
static constexpr StringLiteral LATEST_API_VERSION = V0_1_0;
|
||||
|
||||
template <StringLiteral V>
|
||||
concept SupportedVersion = (V == V0_0_0) || (V == V0_1_0);
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
@@ -7,6 +7,7 @@ function(add_ck_builder_test test_name)
|
||||
target_include_directories(${test_name} PRIVATE
|
||||
"${PROJECT_SOURCE_DIR}/experimental/builder/include"
|
||||
"${PROJECT_SOURCE_DIR}/include"
|
||||
"${CMAKE_CURRENT_SOURCE_DIR}"
|
||||
)
|
||||
target_compile_options(${test_name} PRIVATE
|
||||
-Wno-global-constructors
|
||||
@@ -24,3 +25,11 @@ add_ck_builder_test(test_get_instance_string
|
||||
test_get_instance_string.cpp)
|
||||
|
||||
add_ck_builder_test(test_inline_diff test_inline_diff.cpp testing_utils.cpp)
|
||||
|
||||
add_ck_builder_test(test_ckb_build_fwd_instances
|
||||
conv/test_ckb_conv_fwd_2d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp32.cpp
|
||||
conv/test_ckb_conv_fwd_3d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_fp32.cpp)
|
||||
47
experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp
Normal file
47
experimental/builder/test/conv/test_ckb_conv_fwd_2d_bf16.cpp
Normal file
@@ -0,0 +1,47 @@
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
class FwdConv2DBF16Test : public FwdConvBuilderTestBase
|
||||
{
|
||||
};
|
||||
|
||||
// 2D BF16 NHWGC (channels-last) with Pipeline V1 and DEFAULT
|
||||
TEST_F(FwdConv2DBF16Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_ChannelsLast)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
run_test<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V1,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
}
|
||||
|
||||
// 2D BF16 NHWGC (channels-last) with Pipeline V5 and FILTER_3x3
|
||||
TEST_F(FwdConv2DBF16Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_BF16_NHWGC_Filter3x3)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NHWGC_GKYXC_NHWGK,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
run_test<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V5,
|
||||
ConvFwdSpecialization::FILTER_3x3>();
|
||||
}
|
||||
26
experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp
Normal file
26
experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp16.cpp
Normal file
@@ -0,0 +1,26 @@
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
class FwdConv2DFP16Test : public FwdConvBuilderTestBase
|
||||
{
|
||||
};
|
||||
|
||||
TEST_F(FwdConv2DFP16Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP16_GNHWC)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
run_test<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V3,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
}
|
||||
26
experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp
Normal file
26
experimental/builder/test/conv/test_ckb_conv_fwd_2d_fp32.cpp
Normal file
@@ -0,0 +1,26 @@
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
class FwdConv2DFP32Test : public FwdConvBuilderTestBase
|
||||
{
|
||||
};
|
||||
|
||||
TEST_F(FwdConv2DFP32Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_2D_FP32_NGCHW_GKCYX)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout2D> FwdConvSignature{
|
||||
.spatial_dim = 2,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout2D::NGCHW_GKCYX_NGKHW,
|
||||
.data_type = DataType::FP32,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
run_test<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V4,
|
||||
ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0>();
|
||||
}
|
||||
27
experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp
Normal file
27
experimental/builder/test/conv/test_ckb_conv_fwd_3d_bf16.cpp
Normal file
@@ -0,0 +1,27 @@
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
class FwdConv3DBF16Test : public FwdConvBuilderTestBase
|
||||
{
|
||||
};
|
||||
|
||||
// 3D BF16 GNDHWC (group-first, channels-last) with Pipeline V3 and DEFAULT
|
||||
TEST_F(FwdConv3DBF16Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_BF16_GNDHWC)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout3D> FwdConvSignature{
|
||||
.spatial_dim = 3,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout3D::GNDHWC_GKZYXC_GNDHWK,
|
||||
.data_type = DataType::BF16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
run_test<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V3,
|
||||
ConvFwdSpecialization::DEFAULT>();
|
||||
}
|
||||
27
experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp
Normal file
27
experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp16.cpp
Normal file
@@ -0,0 +1,27 @@
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
class FwdConv3DFP16Test : public FwdConvBuilderTestBase
|
||||
{
|
||||
};
|
||||
|
||||
// 3D FP16 NDHWGC (channels-last) with Pipeline V4 and FILTER_1X1_PAD0
|
||||
TEST_F(FwdConv3DFP16Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP16_NDHWGC_ChannelsLast)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout3D> FwdConvSignature{
|
||||
.spatial_dim = 3,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout3D::NDHWGC_GKZYXC_NDHWGK,
|
||||
.data_type = DataType::FP16,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
run_test<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V4,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
}
|
||||
27
experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp
Normal file
27
experimental/builder/test/conv/test_ckb_conv_fwd_3d_fp32.cpp
Normal file
@@ -0,0 +1,27 @@
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
|
||||
using namespace ck_tile::builder::test_utils;
|
||||
|
||||
class FwdConv3DFP32Test : public FwdConvBuilderTestBase
|
||||
{
|
||||
};
|
||||
|
||||
// 3D FP32 NGCDHW (channels-first) with Pipeline V1 and FILTER_1X1_PAD0
|
||||
TEST_F(FwdConv3DFP32Test,
|
||||
Create_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3_Instance_3D_FP32_ChannelsFirst)
|
||||
{
|
||||
constexpr ConvSignature<GroupConvLayout3D> FwdConvSignature{
|
||||
.spatial_dim = 3,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout3D::NGCDHW_GKCZYX_NGKDHW,
|
||||
.data_type = DataType::FP32,
|
||||
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
|
||||
|
||||
constexpr ThreadBlock FwdThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
run_test<FwdConvSignature,
|
||||
FwdThreadBlock,
|
||||
BlockGemmPipelineVersion::V1,
|
||||
ConvFwdSpecialization::FILTER_1X1_PAD0>();
|
||||
}
|
||||
119
experimental/builder/test/impl/conv_algorithm_types.hpp
Normal file
119
experimental/builder/test/impl/conv_algorithm_types.hpp
Normal file
@@ -0,0 +1,119 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
namespace ckb = ck_tile::builder;
|
||||
|
||||
// Convenience struct for a tuple of m, n, and k values.
|
||||
template <typename T>
|
||||
struct MNK
|
||||
{
|
||||
T m{};
|
||||
T n{};
|
||||
T k{};
|
||||
};
|
||||
|
||||
// Specify thread block dimensions for a GEMM.
|
||||
struct ThreadBlock
|
||||
{
|
||||
// Thread block size.
|
||||
size_t block_size;
|
||||
// Size of the submatrix problem in a thread block.
|
||||
MNK<size_t> tile_size;
|
||||
};
|
||||
static_assert(ckb::ThreadBlockDescriptor<ThreadBlock>);
|
||||
|
||||
// Describe gridwise GEMM parameters.
|
||||
struct GridwiseGemm
|
||||
{
|
||||
// NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!!
|
||||
size_t ak1 = 0;
|
||||
size_t bk1 = 0;
|
||||
size_t m_per_xdl = 0;
|
||||
size_t n_per_xdl = 0;
|
||||
size_t m_xdl_per_wave = 0;
|
||||
size_t n_xdl_per_wave = 0;
|
||||
};
|
||||
static_assert(ckb::GridwiseGemmDescriptor<GridwiseGemm>);
|
||||
|
||||
// Describe Aand B block transfer thread cluster lengths.
|
||||
struct BlockTransfer
|
||||
{
|
||||
size_t k0;
|
||||
size_t m_n;
|
||||
size_t k1;
|
||||
};
|
||||
static_assert(ckb::BlockTransferDescriptor<BlockTransfer>);
|
||||
|
||||
// Describe C block transfer thread cluster lengths.
|
||||
struct ThreadCluster
|
||||
{
|
||||
size_t m_block;
|
||||
size_t m_wave_per_xdl;
|
||||
size_t n_block;
|
||||
size_t n_wave_per_xdl;
|
||||
};
|
||||
static_assert(ThreadClusterDescriptor<ThreadCluster>);
|
||||
|
||||
struct LdsTransfer
|
||||
{
|
||||
size_t src_vector_dim;
|
||||
size_t src_scalar_per_vector;
|
||||
size_t lds_dst_scalar_per_vector;
|
||||
bool is_direct_load;
|
||||
bool lds_padding;
|
||||
};
|
||||
static_assert(LdsTransferDescriptor<LdsTransfer>);
|
||||
|
||||
struct Epilogue
|
||||
{
|
||||
size_t m_xdl_per_wave_per_shuffle;
|
||||
size_t n_xdl_per_wave_per_shuffle;
|
||||
size_t scalar_per_vector;
|
||||
};
|
||||
static_assert(EpilogueDescriptor<Epilogue>);
|
||||
|
||||
struct AccessOrder
|
||||
{
|
||||
std::array<size_t, 3> order;
|
||||
};
|
||||
static_assert(AccessOrderDescriptor<AccessOrder>);
|
||||
|
||||
struct BlockTransferABC
|
||||
{
|
||||
BlockTransfer block_transfer_a;
|
||||
BlockTransfer block_transfer_b;
|
||||
ThreadCluster thread_cluster_dims_c;
|
||||
LdsTransfer lds_transfer_a;
|
||||
LdsTransfer lds_transfer_b;
|
||||
Epilogue epilogue_c;
|
||||
AccessOrder block_transfer_access_order_a;
|
||||
AccessOrder block_transfer_access_order_b;
|
||||
AccessOrder src_access_order_a;
|
||||
AccessOrder src_access_order_b;
|
||||
};
|
||||
|
||||
struct ConvAlgorithm
|
||||
{
|
||||
ThreadBlock thread_block;
|
||||
GridwiseGemm gridwise_gemm;
|
||||
BlockTransferABC block_transfer;
|
||||
BlockGemmPipelineVersion pipeline_version;
|
||||
ConvFwdSpecialization fwd_specialization;
|
||||
};
|
||||
static_assert(ckb::ConvAlgorithmDescriptor<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesThreadBlock<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesGridwiseGemm<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesBlockTransfer<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesLdsTransfer<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesThreadClusterAccessOrder<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesSourceAccessOrder<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesGemmPipelineVersion<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesFwdConcSpecialization<ConvAlgorithm>);
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
23
experimental/builder/test/impl/conv_signature_types.hpp
Normal file
23
experimental/builder/test/impl/conv_signature_types.hpp
Normal file
@@ -0,0 +1,23 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/conv_signature_concepts.hpp"
|
||||
|
||||
namespace ck_tile::builder::test {
|
||||
|
||||
template <typename GroupConvLayout>
|
||||
struct ConvSignature
|
||||
{
|
||||
int spatial_dim;
|
||||
ConvDirection direction;
|
||||
GroupConvLayout layout;
|
||||
DataType data_type;
|
||||
ElementwiseOperation elementwise_operation;
|
||||
};
|
||||
static_assert(ConvSignatureDescriptor<ConvSignature<GroupConvLayout1D>>);
|
||||
static_assert(ConvSignatureDescriptor<ConvSignature<GroupConvLayout2D>>);
|
||||
static_assert(ConvSignatureDescriptor<ConvSignature<GroupConvLayout3D>>);
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
103
experimental/builder/test/utils/ckb_conv_test_common.hpp
Normal file
103
experimental/builder/test/utils/ckb_conv_test_common.hpp
Normal file
@@ -0,0 +1,103 @@
|
||||
#pragma once
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "impl/conv_algorithm_types.hpp"
|
||||
#include "impl/conv_signature_types.hpp"
|
||||
#include "ck_tile/builder/conv_builder.hpp"
|
||||
|
||||
namespace ck_tile::builder::test_utils {
|
||||
|
||||
using namespace ck_tile::builder;
|
||||
using namespace test;
|
||||
|
||||
// Common test base class
|
||||
class FwdConvBuilderTestBase : public ::testing::Test
|
||||
{
|
||||
};
|
||||
|
||||
// Common test implementation
|
||||
template <auto FwdConvSignature,
|
||||
ThreadBlock FwdThreadBlock,
|
||||
BlockGemmPipelineVersion FwdPipelineVersion,
|
||||
ConvFwdSpecialization FwdConvSpecialization>
|
||||
constexpr void run_test()
|
||||
{
|
||||
constexpr GridwiseGemm FwdGemmParams{.ak1 = 8,
|
||||
.bk1 = 8,
|
||||
.m_per_xdl = 32,
|
||||
.n_per_xdl = 32,
|
||||
.m_xdl_per_wave = 4,
|
||||
.n_xdl_per_wave = 4};
|
||||
|
||||
constexpr BlockTransferABC FwdBlockTransfer{.block_transfer_a = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.block_transfer_b = {.k0 = 4, .m_n = 64, .k1 = 1},
|
||||
.thread_cluster_dims_c = {.m_block = 1,
|
||||
.m_wave_per_xdl = 32,
|
||||
.n_block = 1,
|
||||
.n_wave_per_xdl = 8},
|
||||
.lds_transfer_a = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 2,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.lds_transfer_b = {.src_vector_dim = 2,
|
||||
.src_scalar_per_vector = 8,
|
||||
.lds_dst_scalar_per_vector = 8,
|
||||
.is_direct_load = false,
|
||||
.lds_padding = false},
|
||||
.epilogue_c = {.m_xdl_per_wave_per_shuffle = 1,
|
||||
.n_xdl_per_wave_per_shuffle = 1,
|
||||
.scalar_per_vector = 8},
|
||||
.block_transfer_access_order_a = {1, 0, 2},
|
||||
.block_transfer_access_order_b = {1, 0, 2},
|
||||
.src_access_order_a = {1, 0, 2},
|
||||
.src_access_order_b = {1, 0, 2}};
|
||||
|
||||
constexpr ConvAlgorithm FwdConvAlgorithm{.thread_block = FwdThreadBlock,
|
||||
.gridwise_gemm = FwdGemmParams,
|
||||
.block_transfer = FwdBlockTransfer,
|
||||
.pipeline_version = FwdPipelineVersion,
|
||||
.fwd_specialization = FwdConvSpecialization};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
auto instance = typename Builder::Instance{};
|
||||
|
||||
const auto kernel_string = instance.GetTypeString();
|
||||
std::cout << "Generated kernel: " << kernel_string << std::endl;
|
||||
EXPECT_GT(kernel_string.size(), 0);
|
||||
|
||||
EXPECT_TRUE(kernel_string.starts_with("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3"));
|
||||
|
||||
// Verify pipeline version is correct
|
||||
if(FwdPipelineVersion == BlockGemmPipelineVersion::V1)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v1") != std::string::npos);
|
||||
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V3)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v3") != std::string::npos);
|
||||
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V4)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v4") != std::string::npos);
|
||||
else if(FwdPipelineVersion == BlockGemmPipelineVersion::V5)
|
||||
EXPECT_TRUE(kernel_string.find("BlkGemmPipelineVersion: v5") != std::string::npos);
|
||||
|
||||
// Verify specialization is correct
|
||||
if(FwdConvSpecialization == ConvFwdSpecialization::DEFAULT)
|
||||
EXPECT_TRUE(kernel_string.find("Default") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_1X1_STRIDE1_PAD0)
|
||||
EXPECT_TRUE(kernel_string.find("Filter1x1Stride1Pad0") != std::string::npos);
|
||||
else if(FwdConvSpecialization == ConvFwdSpecialization::FILTER_3x3)
|
||||
EXPECT_TRUE(kernel_string.find("Filter3x3") != std::string::npos);
|
||||
|
||||
const auto invoker_ptr = instance.MakeInvokerPointer();
|
||||
EXPECT_NE(invoker_ptr, nullptr);
|
||||
}
|
||||
|
||||
// Common thread block configurations
|
||||
constexpr ThreadBlock DefaultThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 256, .n = 256, .k = 32}};
|
||||
|
||||
constexpr ThreadBlock SmallThreadBlock{.block_size = 256,
|
||||
.tile_size = {.m = 128, .n = 128, .k = 32}};
|
||||
|
||||
} // namespace ck_tile::builder::test_utils
|
||||
@@ -552,6 +552,8 @@ struct PassThrough
|
||||
{
|
||||
y = type_convert<bf8_t>(x);
|
||||
}
|
||||
|
||||
static constexpr const char* name = "PassThrough";
|
||||
};
|
||||
|
||||
struct UnaryConvert
|
||||
|
||||
Reference in New Issue
Block a user