Merge commit '04612c30ceab818cd6c03a3e833a6c6d1a21dafa' into develop

This commit is contained in:
assistant-librarian[bot]
2025-12-08 11:12:53 +00:00
parent 66f05c1fbf
commit e5a3277261
79 changed files with 1608 additions and 232 deletions

View File

@@ -6,6 +6,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
### Added
* Added support for explicit GEMM in CK_TILE grouped convolution forward and backward weight.
* Added TF32 convolution support on gfx942 and gfx950 in CK. It can be enabled or disabled via the `DTYPES` value "tf32".
### Changed

View File

@@ -92,6 +92,10 @@ if (DTYPES)
add_definitions(-DCK_ENABLE_FP32)
set(CK_ENABLE_FP32 "ON")
endif()
if (DTYPES MATCHES "tf32")
# definition will be added based on the GPU target in the following section
set(CK_ENABLE_TF32 "ON")
endif()
if (DTYPES MATCHES "fp64")
add_definitions(-DCK_ENABLE_FP64)
set(CK_ENABLE_FP64 "ON")
@@ -106,6 +110,7 @@ else()
set(CK_ENABLE_INT8 "ON")
set(CK_ENABLE_FP16 "ON")
set(CK_ENABLE_FP32 "ON")
set(CK_ENABLE_TF32 "ON")
set(CK_ENABLE_FP64 "ON")
set(CK_ENABLE_BF16 "ON")
set(CK_ENABLE_FP8 "ON")
@@ -282,6 +287,15 @@ if (SUPPORTED_GPU_TARGETS MATCHES "gfx950")
set(CK_GFX950_SUPPORT "ON")
endif()
if ((SUPPORTED_GPU_TARGETS MATCHES "gfx942" OR SUPPORTED_GPU_TARGETS MATCHES "gfx95") AND CK_ENABLE_TF32)
add_definitions(-DCK_ENABLE_TF32)
set(CK_ENABLE_TF32 "ON")
else()
message(STATUS "Disabling TF32 instances")
remove_definitions(-DCK_ENABLE_TF32)
set(CK_ENABLE_TF32 "OFF")
endif()
option(CK_USE_FP8_ON_UNSUPPORTED_ARCH "Enable FP8 GEMM instances on older architectures" OFF)
if(CK_USE_FP8_ON_UNSUPPORTED_ARCH AND (SUPPORTED_GPU_TARGETS MATCHES "gfx90a" OR SUPPORTED_GPU_TARGETS MATCHES "gfx908"))
add_definitions(-DCK_USE_FP8_ON_UNSUPPORTED_ARCH)
@@ -651,6 +665,9 @@ IF(IS_DIRECTORY "${PROJECT_SOURCE_DIR}/library/src/tensor_operation_instance/gpu
if(("${cmake_instance}" MATCHES "fp32" OR "${cmake_instance}" MATCHES "_f32") AND DTYPES MATCHES "fp32")
set(add_inst 1)
endif()
if(("${cmake_instance}" MATCHES "tf32" OR "${cmake_instance}" MATCHES "_tf32") AND DTYPES MATCHES "tf32")
set(add_inst 1)
endif()
if(("${cmake_instance}" MATCHES "fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64")
set(add_inst 1)
endif()

View File

@@ -187,7 +187,7 @@ limit the number of threads. For example, if you have a 128-core CPU and 128 Gb
Additional cmake flags can be used to significantly speed-up the build:
* `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;fp16;fp8;bf16;int8" to build
* `DTYPES` (default is not set) can be set to any subset of "fp64;fp32;tf32;fp16;fp8;bf16;int8" to build
instances of select data types only. The main default data types are fp32 and fp16; you can safely skip
other data types.

View File

@@ -27,6 +27,9 @@ if (DTYPES)
add_definitions(-DCK_ENABLE_FP32)
set(CK_ENABLE_FP32 "ON")
endif()
if (DTYPES MATCHES "tf32")
set(CK_ENABLE_TF32 "ON")
endif()
if (DTYPES MATCHES "fp64")
add_definitions(-DCK_ENABLE_FP64)
set(CK_ENABLE_FP64 "ON")
@@ -41,6 +44,7 @@ else()
set(CK_ENABLE_INT8 "ON")
set(CK_ENABLE_FP16 "ON")
set(CK_ENABLE_FP32 "ON")
set(CK_ENABLE_TF32 "ON")
set(CK_ENABLE_FP64 "ON")
set(CK_ENABLE_BF16 "ON")
if (GPU_TARGETS MATCHES "gfx94")
@@ -67,6 +71,14 @@ if (GPU_TARGETS)
add_definitions(-DCK_USE_FNUZ_FP8)
set(CK_USE_FNUZ_FP8 "ON")
endif()
if ((GPU_TARGETS MATCHES "gfx942" OR GPU_TARGETS MATCHES "gfx95") AND CK_ENABLE_TF32)
add_definitions(-DCK_ENABLE_TF32)
set(CK_ENABLE_TF32 "ON")
else()
message(STATUS "Disabling TF32 instances for this target")
remove_definitions(-DCK_ENABLE_TF32)
set(CK_ENABLE_TF32 "OFF")
endif()
else()
add_definitions(-DCK_USE_WMMA -DCK_USE_XDL)
set(CK_USE_XDL "ON")

View File

@@ -95,6 +95,47 @@ concept AccessOrderDescriptor = requires(T t) {
{ t.order } -> std::convertible_to<std::array<size_t, 3>>;
};
// Concept for thread block dimensions for a GEMM problem for CK Tile (Block
// size is deduced from block gemm structure).
template <typename T>
concept TileThreadBlockDescriptor = requires(T t) {
    // Per-block tile dimensions (M, N, K) of the GEMM problem.
    { t.tile_size.m } -> std::convertible_to<size_t>;
    { t.tile_size.n } -> std::convertible_to<size_t>;
    { t.tile_size.k } -> std::convertible_to<size_t>;
};
// Concept for per-tensor vector transfer widths (scalars per vector) for a
// GEMM problem for CK Tile.
template <typename T>
concept TileTransferDescriptor = requires(T t) {
    // Number of scalars moved per vectorized access, per tensor (A, B, C).
    { t.a_scalar_per_vector } -> std::convertible_to<size_t>;
    { t.b_scalar_per_vector } -> std::convertible_to<size_t>;
    { t.c_scalar_per_vector } -> std::convertible_to<size_t>;
};
// Concept to check if struct specifies block GEMM (CK Tile).
template <typename T>
concept TileBlockGemmDescriptor = requires(T t) {
    // Number of warps along each GEMM dimension.
    { t.warps.m } -> std::convertible_to<int>;
    { t.warps.n } -> std::convertible_to<int>;
    { t.warps.k } -> std::convertible_to<int>;
    // Tile dimensions processed by a single warp.
    { t.warp_tile.m } -> std::convertible_to<int>;
    { t.warp_tile.n } -> std::convertible_to<int>;
    { t.warp_tile.k } -> std::convertible_to<int>;
    // Pipeline configuration: double LDS buffering, wave-group count,
    // pipeline version and scheduler selection.
    { t.double_smem_buffer } -> std::convertible_to<bool>;
    { t.num_wave_groups } -> std::convertible_to<int>;
    { t.pipeline_version } -> std::convertible_to<PipelineVersion>;
    { t.scheduler } -> std::convertible_to<PipelineScheduler>;
};
// Concept to check if struct specifies optimizations (CK Tile).
template <typename T>
concept TileOptimizationsDescriptor = requires(T t) {
    // Optional kernel optimizations: group merging, image splitting, and
    // the explicit-GEMM lowering of the convolution.
    { t.num_groups_to_merge } -> std::convertible_to<int>;
    { t.split_image } -> std::convertible_to<bool>;
    { t.explicit_gemm } -> std::convertible_to<bool>;
};
// Base requirement for all ConvAlgorithm concepts, i.e., all conv algorithm concepts must meet this
// concept.
template <typename T>
@@ -110,6 +151,12 @@ concept SpecifiesThreadBlock = requires {
{ T::thread_block } -> ThreadBlockDescriptor;
};
// Concept to check if struct specifies thread block info (CK Tile).
template <typename T>
concept SpecifiesTileThreadBlock = requires {
{ T::thread_block } -> TileThreadBlockDescriptor;
};
// Concept to check if a struct specifies gridwise XDL GEMM info.
template <typename T>
concept SpecifiesGridwiseXdlGemm = requires {
@@ -130,6 +177,14 @@ concept SpecifiesBlockTransfer = requires(T t) {
{ T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor;
};
// Concept to check if a struct specifies convolution scalar per vector info for A, B and C.
template <typename T>
// NOTE(review): the requires-expression parameter 't' is unused; a
// parameter-less requires() would suffice here.
concept SpecifiesTileTransfer = requires(T t) {
    { T::transfer.a_scalar_per_vector } -> std::convertible_to<size_t>;
    { T::transfer.b_scalar_per_vector } -> std::convertible_to<size_t>;
    { T::transfer.c_scalar_per_vector } -> std::convertible_to<size_t>;
};
// Concept to check if a struct specifies LDS transfer info for tensors A, B, and C.
template <typename T>
concept SpecifiesLdsTransfer = requires(T t) {
@@ -159,8 +214,36 @@ concept SpecifiesBlockGemm = requires {
{ T::block_gemm.scheduler } -> std::convertible_to<PipelineScheduler>;
};
// Concept to check if struct specifies block GEMM (CK Tile).
template <typename T>
concept SpecifiesFwdConcSpecialization = requires {
concept SpecifiesTileBlockGemm = requires {
{ T::block_gemm.warps.m } -> std::convertible_to<int>;
{ T::block_gemm.warps.n } -> std::convertible_to<int>;
{ T::block_gemm.warps.k } -> std::convertible_to<int>;
{ T::block_gemm.warp_tile.m } -> std::convertible_to<int>;
{ T::block_gemm.warp_tile.n } -> std::convertible_to<int>;
{ T::block_gemm.warp_tile.k } -> std::convertible_to<int>;
{ T::block_gemm.double_smem_buffer } -> std::convertible_to<bool>;
{ T::block_gemm.num_wave_groups } -> std::convertible_to<int>;
{ T::block_gemm.pipeline_version } -> std::convertible_to<PipelineVersion>;
{ T::block_gemm.scheduler } -> std::convertible_to<PipelineScheduler>;
};
// Concept to check if struct specifies optimization settings (CK Tile).
template <typename T>
concept SpecifiesTileOptimizations = requires {
    { T::optimizations.num_groups_to_merge } -> std::convertible_to<int>;
    { T::optimizations.split_image } -> std::convertible_to<bool>;
    { T::optimizations.explicit_gemm } -> std::convertible_to<bool>;
};
// Concept to check if a struct specifies a CK Tile convolution specialization.
template <typename T>
concept SpecifiesTileConvSpecialization = requires {
    { T::specialization } -> std::convertible_to<TileConvSpecialization>;
};
// Concept to check if a struct specifies a forward convolution specialization.
template <typename T>
concept SpecifiesFwdConvSpecialization = requires {
    { T::fwd_specialization } -> std::convertible_to<ConvFwdSpecialization>;
};

View File

@@ -15,6 +15,11 @@ concept InputVectorTransferLimits = requires {
Value.lds_dst_scalar_per_vector > 0;
};
// Limits for input and output vector transfer (CK Tile).
template <auto Value>
// All three scalar-per-vector widths (A, B, C) must be strictly positive.
concept TileInputOutputVectorTransferLimits =
    requires { requires Value.a > 0 && Value.b > 0 && Value.c > 0; };
// Limits for output vector transfer.
template <auto Value>
concept OutputVectorTransferLimits = requires {

View File

@@ -59,6 +59,7 @@
#include "ck_tile/builder/factory/conv_fwd_wmma_factory.hpp"
#include "ck_tile/builder/factory/conv_fwd_dl_factory.hpp"
#include "ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp"
#include "ck_tile/builder/factory/conv_tile_factory.hpp"
namespace ck_tile::builder::factory {
@@ -81,6 +82,15 @@ namespace ck_tile::builder::factory {
//
// TODO: Make this dispatch logic much more robust and clear for users.
// CK Tile kernel
template <typename T>
// True when T satisfies every concept required by the CK Tile convolution
// factory: thread block, transfer widths, tile specialization, block GEMM
// parameters and optimization knobs.
consteval bool IsTileAlgorithm()
{
    return ConvAlgorithmDescriptor<T> && SpecifiesTileThreadBlock<T> && SpecifiesTileTransfer<T> &&
           SpecifiesTileConvSpecialization<T> && SpecifiesTileBlockGemm<T> &&
           SpecifiesTileOptimizations<T>;
}
// XDL-based kernel with V3 pipeline structure (newer block GEMM pipeline)
template <typename T>
consteval bool IsXdlV3Algorithm()
@@ -88,7 +98,7 @@ consteval bool IsXdlV3Algorithm()
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
SpecifiesFwdConcSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesBlockGemm<T>;
}
@@ -99,7 +109,7 @@ consteval bool IsXdlAlgorithm()
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseXdlGemm<T> &&
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
SpecifiesFwdConcSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesNumPrefetchStages<T> && SpecifiesNumGroupsToMerge<T> &&
SpecifiesLoopScheduler<T>;
}
@@ -111,7 +121,7 @@ consteval bool IsWmmaAlgorithm()
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> && SpecifiesGridwiseWmmaGemm<T> &&
SpecifiesBlockTransfer<T> && SpecifiesLdsTransfer<T> &&
SpecifiesThreadClusterAccessOrder<T> && SpecifiesSourceAccessOrder<T> &&
SpecifiesFwdConcSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesNumPrefetchStages<T> && SpecifiesLoopScheduler<T>;
}
@@ -120,7 +130,7 @@ template <typename T>
consteval bool IsDlAlgorithm()
{
return ConvAlgorithmDescriptor<T> && SpecifiesThreadBlock<T> &&
SpecifiesFwdConcSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesFwdConvSpecialization<T> && SpecifiesGemmSpecialization<T> &&
SpecifiesDlThreadConfig<T> && SpecifiesDlThreadCluster<T> &&
SpecifiesDlBlockTransfer<T> && SpecifiesDlEpilogue<T>;
}
@@ -137,10 +147,15 @@ template <ConvSignatureDescriptor auto SIGNATURE,
StringLiteral VERSION>
constexpr auto make_conv_instance()
{
if constexpr(ConvDirectionIsForward<SIGNATURE>)
{
using AlgoType = std::remove_const_t<decltype(ALGORITHM)>;
using AlgoType = std::remove_const_t<decltype(ALGORITHM)>;
// CK Tile supports common factory for each direction
if constexpr(IsTileAlgorithm<AlgoType>())
{
return typename ConvTileFactory<SIGNATURE, ALGORITHM, VERSION>::Instance{};
}
else if constexpr(ConvDirectionIsForward<SIGNATURE>)
{
if constexpr(IsXdlV3Algorithm<AlgoType>())
{
return typename ConvFwdXdlV3Factory<SIGNATURE, ALGORITHM, VERSION>::Instance{};

View File

@@ -7,11 +7,11 @@
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {

View File

@@ -8,12 +8,12 @@
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {

View File

@@ -8,12 +8,12 @@
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {

View File

@@ -8,12 +8,12 @@
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {

View File

@@ -8,12 +8,12 @@
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace ck_tile::builder::factory {

View File

@@ -0,0 +1,131 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/epilogue.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/grouped_convolution.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/conv_algorithm_limits.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_block_transfer.hpp"
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_thread_block.hpp"
#include "ck_tile/builder/factory/helpers/ck_tile/conv_tile_kernel_directions.hpp"
namespace ck_tile::builder::factory {
// Factory for CK Tile Grouped Convolution kernels.
template <ConvSignatureDescriptor auto SIGNATURE,
          ConvAlgorithmDescriptor auto ALGORITHM,
          StringLiteral VERSION>
struct ConvTileFactory
{
    // Problem dimensionality (1D/2D/3D convolution) taken from the signature.
    static constexpr size_t SPATIAL_DIM = SIGNATURE.spatial_dim;
    // Map builder-level descriptors to CK Tile layouts, tensor types and
    // elementwise operator types.
    using Layouts = internal::TileConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
    using Types = internal::TileConvTensorTypes<SIGNATURE.data_type>;
    using Ops = internal::TileElementwiseOps<SIGNATURE>;
    using AlgorithmType = decltype(ALGORITHM);
    // Extract the algorithm's tuning knobs into CK Tile-native constants.
    static constexpr auto CONV_SPECIALIZATION = internal::SetTileConvSpecialization<ALGORITHM>();
    static constexpr auto BLOCK = internal::SetTileThreadBlockInfo<ALGORITHM>();
    static constexpr auto BLOCK_GEMM = internal::SetTileBlockGemm<ALGORITHM>();
    static constexpr auto OPTIMIZATIONS = internal::SetTileOptimizations<ALGORITHM>();
    static constexpr auto SCALAR_PER_VECTOR = internal::SetTileBlockTransfer<ALGORITHM>();
    static constexpr auto CONV_DIRECTION = internal::SetTileConvDirection<SIGNATURE>();
    // Check limits for the algorithm parameters.
    // TODO: Add more limits checks as needed.
    static_assert(TileInputOutputVectorTransferLimits<SCALAR_PER_VECTOR>);
    // Grouped-convolution traits: layouts, vector widths and optimizations.
    using GroupedConvTraitsType = ck_tile::GroupedConvTraits<SPATIAL_DIM,
                                                             CONV_SPECIALIZATION,
                                                             typename Layouts::ALayout,
                                                             typename Layouts::BLayout,
                                                             typename Layouts::DsLayout,
                                                             typename Layouts::ELayout,
                                                             SCALAR_PER_VECTOR.a,
                                                             SCALAR_PER_VECTOR.b,
                                                             SCALAR_PER_VECTOR.c,
                                                             OPTIMIZATIONS.num_groups_to_merge,
                                                             OPTIMIZATIONS.split_image,
                                                             OPTIMIZATIONS.explicit_gemm>;
    // GEMM shape: per-block tile, warp grid, and per-warp tile sizes.
    using GemmShape = ck_tile::TileGemmShape<
        ck_tile::sequence<BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k>,
        ck_tile::sequence<BLOCK_GEMM.warps.m, BLOCK_GEMM.warps.n, BLOCK_GEMM.warps.k>,
        ck_tile::sequence<BLOCK_GEMM.warp_tile.m, BLOCK_GEMM.warp_tile.n, BLOCK_GEMM.warp_tile.k>>;
    using TilePartitioner = ck_tile::GemmSpatiallyLocalTilePartitioner<
        GemmShape,
        GroupedConvTraitsType::FixedGemmParams::TilePartitionerGroupNum,
        GroupedConvTraitsType::FixedGemmParams::TilePartitionerM01>;
    // Universal GEMM traits: padding, LDS buffering, layouts per direction.
    using GemmUniversalTraits = ck_tile::TileGemmUniversalTraits<
        GroupedConvTraitsType::FixedGemmParams::kPadM,
        GroupedConvTraitsType::FixedGemmParams::kPadN,
        GroupedConvTraitsType::FixedGemmParams::kPadK,
        BLOCK_GEMM.double_smem_buffer,
        typename GroupedConvTraitsType::template GemmLayouts<CONV_DIRECTION>::AsLayout,
        typename GroupedConvTraitsType::template GemmLayouts<CONV_DIRECTION>::BsLayout,
        typename GroupedConvTraitsType::template GemmLayouts<CONV_DIRECTION>::CLayout,
        GroupedConvTraitsType::FixedGemmParams::TransposeC,
        GroupedConvTraitsType::FixedGemmParams::UseStructuredSparsity,
        GroupedConvTraitsType::FixedGemmParams::Persistent,
        BLOCK_GEMM.num_wave_groups>;
    // Pipeline problem description fed into the selected GEMM pipeline.
    using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<
        typename Types::ADataType,
        typename Types::BDataType,
        typename Types::AccDataType,
        GemmShape,
        GemmUniversalTraits,
        BLOCK_GEMM.scheduler,
        typename Ops::AElementwiseOp,
        typename Ops::BElementwiseOp,
        typename Types::EDataType,
        GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
        GroupedConvTraitsType::VectorSizeA,
        GroupedConvTraitsType::VectorSizeB>;
    // Pipeline implementation selected by the algorithm's pipeline version.
    using GemmPipeline = typename internal::TilePipelineType<
        BLOCK_GEMM.pipeline_version>::template GemmPipeline<UniversalGemmProblem>;
    // C-shuffle epilogue writing the E tensor (with optional Ds operands).
    using ConvEpilogue = ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<
        typename Types::ADataType,
        typename Types::BDataType,
        typename Types::DsDataTypes,
        typename Types::AccDataType,
        typename Types::EDataType,
        typename GroupedConvTraitsType::ImplicitGemmDsLayout,
        typename GroupedConvTraitsType::FixedGemmParams::ELayout,
        typename Ops::CDEElementwiseOp,
        BLOCK.per_block.m,
        BLOCK.per_block.n,
        BLOCK_GEMM.warps.m,
        BLOCK_GEMM.warps.n,
        BLOCK_GEMM.warp_tile.m,
        BLOCK_GEMM.warp_tile.n,
        BLOCK_GEMM.warp_tile.k,
        GroupedConvTraitsType::FixedGemmParams::TransposeC,
        // TODO:: This template parameter will be moved inside the kernel
        ck_tile::memory_operation_enum::set,
        BLOCK_GEMM.num_wave_groups,
        GroupedConvTraitsType::FixedGemmParams::FixedVectorSize,
        SCALAR_PER_VECTOR.c>>;
    // Final kernel instance, dispatched on the convolution direction.
    using Instance = typename internal::GroupedConvolutionTileKernel<SIGNATURE,
                                                                     GroupedConvTraitsType,
                                                                     TilePartitioner,
                                                                     GemmPipeline,
                                                                     ConvEpilogue>::Instance;
};
} // namespace ck_tile::builder::factory

View File

@@ -0,0 +1,25 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
namespace ck_tile::builder::factory::internal {
// Vector access widths (scalars per vectorized load/store) for the A, B and
// C tensors. Zero means "not configured"; limits checks require > 0.
struct TileScalarPerVector
{
    size_t a = 0; // input (A) tensor
    size_t b = 0; // weight (B) tensor
    size_t c = 0; // output (C) tensor
};
// Copies the algorithm descriptor's per-tensor vector widths into the
// factory's plain TileScalarPerVector aggregate.
template <ConvAlgorithmDescriptor auto ALGORITHM>
constexpr TileScalarPerVector SetTileBlockTransfer()
{
    TileScalarPerVector scalars{};
    scalars.a = ALGORITHM.transfer.a_scalar_per_vector;
    scalars.b = ALGORITHM.transfer.b_scalar_per_vector;
    scalars.c = ALGORITHM.transfer.c_scalar_per_vector;
    return scalars;
}
} // namespace ck_tile::builder::factory::internal

View File

@@ -0,0 +1,62 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/builder/builder_utils.hpp"
#include "ck_tile/builder/types.hpp"
namespace ck_tile::builder::factory::internal {
// Maps a builder ElementwiseOperation enum value to the corresponding CK Tile
// elementwise operator type. Unlisted values fail at compile time via the
// dependent static_assert in the primary template.
template <ElementwiseOperation Op>
struct ElementwiseOpToCKTile
{
    static_assert(sizeof(UnsupportedEnumValue<Op>) == 0,
                  "Unsupported elementwise operation conversion to CK.");
};
template <>
struct ElementwiseOpToCKTile<ElementwiseOperation::PASS_THROUGH>
{
    using Op = ck_tile::element_wise::PassThrough;
};
template <>
struct ElementwiseOpToCKTile<ElementwiseOperation::SCALE>
{
    using Op = ck_tile::element_wise::Scale;
};
template <>
struct ElementwiseOpToCKTile<ElementwiseOperation::CLAMP>
{
    using Op = ck_tile::element_wise::Clamp;
};
// Returns the CK Tile elementwise-op wrapper for a tensor descriptor.
// Tensors without an explicit operation default to PassThrough.
template <auto TensorDesc>
consteval auto GetTileElementwiseOp()
{
    if constexpr(HasTensorOp<decltype(TensorDesc)>)
    {
        constexpr auto op = TensorDesc.operation.elementwise_operation;
        return ElementwiseOpToCKTile<op>{};
    }
    else
    {
        return ElementwiseOpToCKTile<ElementwiseOperation::PASS_THROUGH>{};
    }
}
// Bundles the CK Tile elementwise operator types for a convolution signature:
// A = input, B = weight, CDE = output.
template <auto Sig>
struct TileElementwiseOps
{
    static constexpr auto input_op = GetTileElementwiseOp<Sig.input>();
    static constexpr auto weight_op = GetTileElementwiseOp<Sig.weight>();
    static constexpr auto output_op = GetTileElementwiseOp<Sig.output>();
    using AElementwiseOp = typename decltype(input_op)::Op;
    using BElementwiseOp = typename decltype(weight_op)::Op;
    using CDEElementwiseOp = typename decltype(output_op)::Op;
};
} // namespace ck_tile::builder::factory::internal

View File

@@ -0,0 +1,88 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/builder/conv_signature_concepts.hpp"
namespace ck_tile::builder::factory::internal {
// Maps a convolution direction (from SIGNATURE) to the corresponding CK Tile
// grouped-convolution kernel type. The primary template is selected only when
// no direction-specific partial specialization matches.
template <ConvSignatureDescriptor auto SIGNATURE,
          typename GroupedConvTraitsType,
          typename TilePartitioner,
          typename GemmPipeline,
          typename ConvEpilogue>
struct GroupedConvolutionTileKernel
{
    // Dependent-false assertion: a plain 'static_assert(false, ...)' inside a
    // template is ill-formed (no diagnostic required) before C++23 and is
    // rejected eagerly by some compilers even when never instantiated. Tying
    // the condition to a template parameter defers the failure to
    // instantiation, matching the UnsupportedEnumValue pattern used elsewhere
    // in the builder helpers.
    static_assert(sizeof(GroupedConvTraitsType) == 0, "Unknown Direction");
};
// Forward convolution: dispatch to the forward CK Tile kernel.
template <ConvSignatureDescriptor auto SIGNATURE,
          typename GroupedConvTraitsType,
          typename TilePartitioner,
          typename GemmPipeline,
          typename ConvEpilogue>
    requires ConvDirectionIsForward<SIGNATURE>
struct GroupedConvolutionTileKernel<SIGNATURE,
                                    GroupedConvTraitsType,
                                    TilePartitioner,
                                    GemmPipeline,
                                    ConvEpilogue>
{
    using Instance = ck_tile::GroupedConvolutionForwardKernel<GroupedConvTraitsType,
                                                              TilePartitioner,
                                                              GemmPipeline,
                                                              ConvEpilogue>;
};
// Backward-data convolution: dispatch to the backward-data CK Tile kernel.
template <ConvSignatureDescriptor auto SIGNATURE,
          typename GroupedConvTraitsType,
          typename TilePartitioner,
          typename GemmPipeline,
          typename ConvEpilogue>
    requires ConvDirectionIsBackwardData<SIGNATURE>
struct GroupedConvolutionTileKernel<SIGNATURE,
                                    GroupedConvTraitsType,
                                    TilePartitioner,
                                    GemmPipeline,
                                    ConvEpilogue>
{
    using Instance = ck_tile::GroupedConvolutionBackwardDataKernel<GroupedConvTraitsType,
                                                                   TilePartitioner,
                                                                   GemmPipeline,
                                                                   ConvEpilogue>;
};
// Backward-weight convolution: dispatch to the backward-weight CK Tile kernel.
template <ConvSignatureDescriptor auto SIGNATURE,
          typename GroupedConvTraitsType,
          typename TilePartitioner,
          typename GemmPipeline,
          typename ConvEpilogue>
    requires ConvDirectionIsBackwardWeight<SIGNATURE>
struct GroupedConvolutionTileKernel<SIGNATURE,
                                    GroupedConvTraitsType,
                                    TilePartitioner,
                                    GemmPipeline,
                                    ConvEpilogue>
{
    using Instance = ck_tile::GroupedConvolutionBackwardWeightKernel<GroupedConvTraitsType,
                                                                     TilePartitioner,
                                                                     GemmPipeline,
                                                                     ConvEpilogue>;
};
// Converts the builder's ConvDirection enum into CK Tile's
// GroupedConvDirection enum.
template <ConvSignatureDescriptor auto SIGNATURE>
consteval ck_tile::GroupedConvDirection SetTileConvDirection()
{
    constexpr auto direction = SIGNATURE.direction;
    using ck_tile_direction = ck_tile::GroupedConvDirection;
    switch(direction)
    {
    case ConvDirection::FORWARD: return ck_tile_direction::FORWARD;
    case ConvDirection::BACKWARD_DATA: return ck_tile_direction::BACKWARD_DATA;
    case ConvDirection::BACKWARD_WEIGHT: return ck_tile_direction::BACKWARD_WEIGHT;
    // 'throw' is not allowed in constant evaluation, so reaching this default
    // in a consteval call is a compile-time error, not a runtime throw.
    default: throw "Unknown Direction";
    }
}
} // namespace ck_tile::builder::factory::internal

View File

@@ -0,0 +1,200 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/builder/conv_signature_concepts.hpp"
namespace ck_tile::builder::factory::internal {
// NOTE(review): this namespace-scope alias hard-codes NWGC and looks like a
// leftover — layout selection goes through LayoutToCKTile below. TODO: confirm
// nothing depends on it and remove.
using ALayout = ck_tile::tensor_layout::convolution::NWGC;
// Maps a builder TensorLayout enum value to the corresponding CK Tile
// convolution layout type. Unlisted values fail at compile time via the
// dependent static_assert in the primary template.
template <TensorLayout Layout>
struct LayoutToCKTile
{
    static_assert(sizeof(UnsupportedEnumValue<Layout>) == 0,
                  "Unsupported layout conversion to CK.");
};
// Bias layouts
template <>
struct LayoutToCKTile<TensorLayout::G_K_strided>
{
    using type = ck_tile::tensor_layout::convolution::G_K;
};
template <>
struct LayoutToCKTile<TensorLayout::GC>
{
    using type = ck_tile::tensor_layout::convolution::GC;
};
template <>
struct LayoutToCKTile<TensorLayout::G_C_strided>
{
    using type = ck_tile::tensor_layout::convolution::G_C;
};
// Input 1D
template <>
struct LayoutToCKTile<TensorLayout::NWGC>
{
    using type = ck_tile::tensor_layout::convolution::NWGC;
};
template <>
struct LayoutToCKTile<TensorLayout::GNWC>
{
    using type = ck_tile::tensor_layout::convolution::GNWC;
};
// Input 2D
template <>
struct LayoutToCKTile<TensorLayout::NHWGC>
{
    using type = ck_tile::tensor_layout::convolution::NHWGC;
};
template <>
struct LayoutToCKTile<TensorLayout::GNHWC>
{
    using type = ck_tile::tensor_layout::convolution::GNHWC;
};
// Input 3D
template <>
struct LayoutToCKTile<TensorLayout::NDHWGC>
{
    using type = ck_tile::tensor_layout::convolution::NDHWGC;
};
template <>
struct LayoutToCKTile<TensorLayout::GNDHWC>
{
    using type = ck_tile::tensor_layout::convolution::GNDHWC;
};
// Weight 1D
template <>
struct LayoutToCKTile<TensorLayout::GKXC>
{
    using type = ck_tile::tensor_layout::convolution::GKXC;
};
template <>
struct LayoutToCKTile<TensorLayout::GKCX>
{
    using type = ck_tile::tensor_layout::convolution::GKCX;
};
// Weight 2D
template <>
struct LayoutToCKTile<TensorLayout::GKYXC>
{
    using type = ck_tile::tensor_layout::convolution::GKYXC;
};
template <>
struct LayoutToCKTile<TensorLayout::GKCYX>
{
    using type = ck_tile::tensor_layout::convolution::GKCYX;
};
// Weight 3D
template <>
struct LayoutToCKTile<TensorLayout::GKCZYX>
{
    using type = ck_tile::tensor_layout::convolution::GKCZYX;
};
template <>
struct LayoutToCKTile<TensorLayout::GKZYXC>
{
    using type = ck_tile::tensor_layout::convolution::GKZYXC;
};
// Output 1D
template <>
struct LayoutToCKTile<TensorLayout::NWGK>
{
    using type = ck_tile::tensor_layout::convolution::NWGK;
};
template <>
struct LayoutToCKTile<TensorLayout::GNWK>
{
    using type = ck_tile::tensor_layout::convolution::GNWK;
};
// Output 2D
template <>
struct LayoutToCKTile<TensorLayout::NHWGK>
{
    using type = ck_tile::tensor_layout::convolution::NHWGK;
};
template <>
struct LayoutToCKTile<TensorLayout::GNHWK>
{
    using type = ck_tile::tensor_layout::convolution::GNHWK;
};
// Output 3D
template <>
struct LayoutToCKTile<TensorLayout::NDHWGK>
{
    using type = ck_tile::tensor_layout::convolution::NDHWGK;
};
template <>
struct LayoutToCKTile<TensorLayout::GNDHWK>
{
    using type = ck_tile::tensor_layout::convolution::GNDHWK;
};
// Value-level helper: returns a default-constructed instance of the CK Tile
// layout type mapped from the given TensorLayout enum value.
template <TensorLayout Layout>
consteval auto TensorLayoutToCKTile()
{
    return typename LayoutToCKTile<Layout>::type{};
}
// Placeholder used when the signature declares no auxiliary (Ds) tensors.
struct EmptyAuxiliaryTileTensorLayout
{
    using type = ck_tile::tuple<>;
};
// Expands an array of auxiliary tensor configs into a ck_tile::tuple of the
// mapped CK Tile layout types, one per config entry.
template <auto AuxiliaryTileTensorConfigsArray, size_t... Indices>
consteval auto GetAuxiliaryTileTensorLayoutTuple(std::index_sequence<Indices...>)
{
    return ck_tile::tuple<
        decltype(TensorLayoutToCKTile<AuxiliaryTileTensorConfigsArray[Indices].layout>())...>{};
}
// Layout tuple for all auxiliary tensors of a convolution signature.
template <auto AuxiliaryTileTensorConfigsValue, size_t SPATIAL_DIM>
    requires(ConvSpatialDim<SPATIAL_DIM>)
struct AuxiliaryTileTensorLayouts
{
    static constexpr auto Size = AuxiliaryTileTensorConfigsValue.size();
    using type = decltype(GetAuxiliaryTileTensorLayoutTuple<AuxiliaryTileTensorConfigsValue>(
        std::make_index_sequence<Size>{}));
};
// TODO: Currently only the output tensor can have auxiliary tensors (e.g., bias).
// Output has an elementwise op with auxiliary operands: build their layouts.
template <auto Signature, size_t SPATIAL_DIM>
    requires(HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
consteval auto GetAuxiliaryTileTensorLayouts()
{
    return AuxiliaryTileTensorLayouts<Signature.output.operation.auxiliary_operand_configs,
                                      SPATIAL_DIM>{};
}
// No auxiliary operands: fall back to the empty layout tuple.
template <auto Signature, size_t SPATIAL_DIM>
    requires(!HasElementwiseOpWithAuxiliaryOperands<decltype(Signature.output)>)
consteval auto GetAuxiliaryTileTensorLayouts()
{
    return EmptyAuxiliaryTileTensorLayout{};
}
// CK Tile layouts for a full convolution: A = input, B = weight, E = output,
// Ds = auxiliary tensors. The requires-clause validates that every layout is
// legal for the given spatial dimensionality.
template <auto Signature, size_t SPATIAL_DIM>
    requires(ConvSpatialDim<SPATIAL_DIM> &&
             ValidConvInputLayoutForSpatialDim<Signature.input.config.layout, SPATIAL_DIM> &&
             ValidConvWeightLayoutForSpatialDim<Signature.weight.config.layout, SPATIAL_DIM> &&
             ValidConvOutputLayoutForSpatialDim<Signature.output.config.layout, SPATIAL_DIM>)
struct TileConvTensorLayouts
{
    using ALayout = decltype(TensorLayoutToCKTile<Signature.input.config.layout>());
    using BLayout = decltype(TensorLayoutToCKTile<Signature.weight.config.layout>());
    using ELayout = decltype(TensorLayoutToCKTile<Signature.output.config.layout>());
    using DsLayout = decltype(GetAuxiliaryTileTensorLayouts<Signature, SPATIAL_DIM>())::type;
};
} // namespace ck_tile::builder::factory::internal

View File

@@ -0,0 +1,87 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/builder/types.hpp"
#include "ck_tile/builder/builder_utils.hpp"
namespace ck_tile::builder::factory::internal {
// Type mappings from builder convolution data type to CK Tile tensor types.
// Primary template: only instantiated for DataType values that have no
// specialization below.
template <DataType T>
struct TileConvTensorTypes
{
    // This will trigger if a specialization for the given DataType is not found.
    // We should always catch this in an earlier validation check.
    // (sizeof(UnsupportedEnumValue<T>) keeps the condition dependent on T so
    // the assert fires only on instantiation, not on template definition.)
    static_assert(sizeof(UnsupportedEnumValue<T>) == 0,
                  "Internal error. Unsupported data type for convolution factory.");
};
// FP16: half-precision tensors, float accumulation.
template <>
struct TileConvTensorTypes<DataType::FP16>
{
    using ADataType = ck_tile::half_t;
    using AComputeType = ck_tile::half_t;
    using BDataType = ck_tile::half_t;
    using BComputeType = ck_tile::half_t;
    using CShuffleDataType = ck_tile::half_t;
    using DsDataTypes = ck_tile::tuple<>;
    using AccDataType = float;
    using EDataType = ck_tile::half_t;
};
// BF16: bfloat16 tensors, float accumulation.
template <>
struct TileConvTensorTypes<DataType::BF16>
{
    using ADataType = ck_tile::bf16_t;
    using AComputeType = ck_tile::bf16_t;
    using BDataType = ck_tile::bf16_t;
    using BComputeType = ck_tile::bf16_t;
    using CShuffleDataType = ck_tile::bf16_t;
    using DsDataTypes = ck_tile::tuple<>;
    using AccDataType = float;
    using EDataType = ck_tile::bf16_t;
};
// FP32: single-precision float throughout.
template <>
struct TileConvTensorTypes<DataType::FP32>
{
    using ADataType = float;
    using AComputeType = float;
    using BDataType = float;
    using BComputeType = float;
    using CShuffleDataType = float;
    using DsDataTypes = ck_tile::tuple<>;
    using AccDataType = float;
    using EDataType = float;
};
// I8: int8 tensors and output, int32 accumulation.
template <>
struct TileConvTensorTypes<DataType::I8>
{
    using ADataType = int8_t;
    using AComputeType = int8_t;
    using BDataType = int8_t;
    using BComputeType = int8_t;
    using CShuffleDataType = int8_t;
    using DsDataTypes = ck_tile::tuple<>;
    using AccDataType = int32_t;
    using EDataType = int8_t;
};
// FP8: fp8 tensors and output, float accumulation.
template <>
struct TileConvTensorTypes<DataType::FP8>
{
    using ADataType = ck_tile::fp8_t;
    using AComputeType = ck_tile::fp8_t;
    using BDataType = ck_tile::fp8_t;
    using BComputeType = ck_tile::fp8_t;
    using CShuffleDataType = ck_tile::fp8_t;
    using DsDataTypes = ck_tile::tuple<>;
    using AccDataType = float;
    using EDataType = ck_tile::fp8_t;
};
} // namespace ck_tile::builder::factory::internal

View File

@@ -0,0 +1,32 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
namespace ck_tile::builder::factory::internal {
// Convenience struct for a tuple of m, n, and k values.
struct TileBlockMNK
{
    int m{}; // tile extent in M
    int n{}; // tile extent in N
    int k{}; // tile extent in K (reduction)
};
// Thread-block level tile configuration for the convolution kernel.
struct TileConvBlock
{
    // Tile extents processed by one thread block.
    TileBlockMNK per_block = {};
};
// Extracts the per-thread-block tile extents from the algorithm descriptor.
template <ConvAlgorithmDescriptor auto ALGORITHM>
constexpr TileConvBlock SetTileThreadBlockInfo()
{
    constexpr auto& tile = ALGORITHM.thread_block.tile_size;
    TileConvBlock block{};
    block.per_block.m = tile.m;
    block.per_block.n = tile.n;
    block.per_block.k = tile.k;
    return block;
}
} // namespace ck_tile::builder::factory::internal

View File

@@ -0,0 +1,158 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/builder/conv_algorithm_concepts.hpp"
#include "ck_tile/builder/types.hpp"
namespace ck_tile::builder::factory::internal {
// Convenience struct for a tuple of m, n, and k values.
struct TileBlockGemmMNK
{
    int m{}; // extent in M
    int n{}; // extent in N
    int k{}; // extent in K
};
// Block-GEMM configuration resolved from the algorithm descriptor.
struct TileBlockGemmSpec
{
    // Number of warps per dimension.
    TileBlockGemmMNK warps = {};
    // Warp-level tile extents per dimension.
    TileBlockGemmMNK warp_tile = {};
    // Use double LDS buffering.
    bool double_smem_buffer = false;
    // Number of wave groups.
    int num_wave_groups = 1;
    // Selected CK Tile pipeline and scheduler.
    ck_tile::GemmPipeline pipeline_version;
    ck_tile::GemmPipelineScheduler scheduler;
};
// Convolution-specific optimization switches.
struct TileOptimizations
{
    int num_groups_to_merge = 1; // conv groups processed per workgroup
    bool split_image = false;    // split image for large tensors
    bool explicit_gemm = false;  // lower to an explicit GEMM
};
// Translates the builder's PipelineScheduler into the CK Tile scheduler
// enumeration. Evaluated at compile time; an unmapped enumerator reaches the
// throw and makes constant evaluation (and therefore compilation) fail.
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck_tile::GemmPipelineScheduler SetTileScheduler()
{
    constexpr auto scheduler = ALGORITHM.block_gemm.scheduler;
    if(scheduler == PipelineScheduler::DEFAULT)
    {
        return ck_tile::GemmPipelineScheduler::Default;
    }
    if(scheduler == PipelineScheduler::INTERWAVE)
    {
        return ck_tile::GemmPipelineScheduler::Interwave;
    }
    if(scheduler == PipelineScheduler::INTRAWAVE)
    {
        return ck_tile::GemmPipelineScheduler::Intrawave;
    }
    throw "Unknown PipelineScheduler";
}
template <ck_tile::GemmPipeline PipelineId>
struct TilePipelineType
{
static_assert(false, "Unknown PipelineScheduler");
};
// BASIC_V1 -> GemmPipelineAGmemBGmemCRegV1.
template <>
struct TilePipelineType<ck_tile::GemmPipeline::BASIC_V1>
{
    template <typename PipelineProblem>
    using GemmPipeline = ck_tile::GemmPipelineAGmemBGmemCRegV1<PipelineProblem>;
};
// MEMORY -> GemmPipelineAgBgCrMem.
template <>
struct TilePipelineType<ck_tile::GemmPipeline::MEMORY>
{
    template <typename PipelineProblem>
    using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<PipelineProblem>;
};
// COMPUTE_V3 -> GemmPipelineAgBgCrCompV3.
template <>
struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V3>
{
    template <typename PipelineProblem>
    using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV3<PipelineProblem>;
};
// COMPUTE_V4 -> GemmPipelineAgBgCrCompV4.
template <>
struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V4>
{
    template <typename PipelineProblem>
    using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV4<PipelineProblem>;
};
// COMPUTE_V5 -> GemmPipelineAgBgCrCompV5.
template <>
struct TilePipelineType<ck_tile::GemmPipeline::COMPUTE_V5>
{
    template <typename PipelineProblem>
    using GemmPipeline = ck_tile::GemmPipelineAgBgCrCompV5<PipelineProblem>;
};
// Translates the builder's PipelineVersion into the CK Tile GemmPipeline
// enumeration. WEIGHT_ONLY has no block-GEMM pipeline; that case and any
// unmapped enumerator abort constant evaluation via throw.
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck_tile::GemmPipeline SetTileBlockGemmPipelineVersion()
{
    constexpr auto version = ALGORITHM.block_gemm.pipeline_version;
    if(version == PipelineVersion::V1)
    {
        return ck_tile::GemmPipeline::BASIC_V1;
    }
    if(version == PipelineVersion::V2)
    {
        return ck_tile::GemmPipeline::MEMORY;
    }
    if(version == PipelineVersion::V3)
    {
        return ck_tile::GemmPipeline::COMPUTE_V3;
    }
    if(version == PipelineVersion::V4)
    {
        return ck_tile::GemmPipeline::COMPUTE_V4;
    }
    if(version == PipelineVersion::V5)
    {
        return ck_tile::GemmPipeline::COMPUTE_V5;
    }
    if(version == PipelineVersion::WEIGHT_ONLY)
    {
        throw "PipelineVersion::WEIGHT_ONLY is not supported for block GEMM pipeline version.";
    }
    throw "Unknown block GEMM PipelineVersion";
}
// Translates the builder's TileConvSpecialization into the CK Tile
// ck_tile::ConvolutionSpecialization enumeration. Evaluated at compile time;
// an unmapped enumerator reaches the throw and fails compilation.
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval ck_tile::ConvolutionSpecialization SetTileConvSpecialization()
{
    constexpr auto specialization = ALGORITHM.specialization;
    using ck_tile_conv_spec = ck_tile::ConvolutionSpecialization;
    switch(specialization)
    {
    case TileConvSpecialization::DEFAULT: return ck_tile_conv_spec::Default;
    case TileConvSpecialization::FILTER_1X1_PAD0: return ck_tile_conv_spec::Filter1x1Pad0;
    case TileConvSpecialization::FILTER_1X1_STRIDE1_PAD0:
        return ck_tile_conv_spec::Filter1x1Stride1Pad0;
    case TileConvSpecialization::FILTER_3x3: return ck_tile_conv_spec::Filter3x3;
    // Message fixed: the switched enum is TileConvSpecialization; the old
    // message referred to the separate ConvFwdSpecialization enum.
    default: throw "Unknown TileConvSpecialization";
    }
}
// Assembles the block-GEMM specification (warp layout, warp tile, buffering,
// wave groups, pipeline and scheduler) from the algorithm descriptor.
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval TileBlockGemmSpec SetTileBlockGemm()
{
    constexpr auto& gemm_desc = ALGORITHM.block_gemm;
    TileBlockGemmSpec spec{};
    spec.warps = {.m = gemm_desc.warps.m, .n = gemm_desc.warps.n, .k = gemm_desc.warps.k};
    spec.warp_tile = {
        .m = gemm_desc.warp_tile.m, .n = gemm_desc.warp_tile.n, .k = gemm_desc.warp_tile.k};
    spec.double_smem_buffer = gemm_desc.double_smem_buffer;
    spec.num_wave_groups    = gemm_desc.num_wave_groups;
    spec.pipeline_version   = SetTileBlockGemmPipelineVersion<ALGORITHM>();
    spec.scheduler          = SetTileScheduler<ALGORITHM>();
    return spec;
}
// Copies the convolution optimization switches out of the algorithm descriptor.
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval TileOptimizations SetTileOptimizations()
{
    constexpr auto& opts = ALGORITHM.optimizations;
    TileOptimizations result{};
    result.num_groups_to_merge = opts.num_groups_to_merge;
    result.split_image         = opts.split_image;
    result.explicit_gemm       = opts.explicit_gemm;
    return result;
}
} // namespace ck_tile::builder::factory::internal

View File

@@ -145,6 +145,15 @@ enum struct GemmSpecialization
MNKOPadding
};
// Enums for the CK Tile convolution specialization.
enum class TileConvSpecialization
{
    DEFAULT,                 // no restriction on filter/stride/padding
    FILTER_1X1_PAD0,         // 1x1 filter, zero padding
    FILTER_1X1_STRIDE1_PAD0, // 1x1 filter, unit stride, zero padding
    FILTER_3x3               // 3x3 filter
    // NOTE(review): lowercase 'x' in FILTER_3x3 is inconsistent with the
    // FILTER_1X1_* spellings above; renaming would break existing callers.
};
// Enums for the forward convolution specialization.
enum class ConvFwdSpecialization
{

View File

@@ -90,7 +90,7 @@ add_ck_builder_test(test_ckb_conv_builder
# Tests convolution trait selection and configuration
add_ck_builder_test(test_ckb_conv_traits
conv/test_conv_traits.cpp)
conv/ck/test_conv_traits.cpp)
# Tests convolution problem description and parameter handling
add_ck_builder_test(test_ckb_conv_description
@@ -119,19 +119,22 @@ add_ck_builder_test(test_ckb_instance_string
# Tests the forward convolution builder across multiple data types and dimensions.
# Individual tests are split into separate files to enable parallel compilation.
add_ck_builder_test(test_ckb_build_fwd_instances
conv/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp
conv/test_ckb_conv_fwd_1d_fp16.cpp
conv/test_ckb_conv_fwd_1d_bf16.cpp
conv/test_ckb_conv_fwd_1d_i8.cpp
conv/test_ckb_conv_fwd_2d_fp8.cpp
conv/test_ckb_conv_fwd_2d_bf16.cpp
conv/test_ckb_conv_fwd_2d_fp16.cpp
conv/test_ckb_conv_fwd_2d_fp32.cpp
conv/test_ckb_conv_fwd_2d_dl_fp16.cpp
conv/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp
conv/test_ckb_conv_fwd_3d_bf16.cpp
conv/test_ckb_conv_fwd_3d_fp16.cpp
conv/test_ckb_conv_fwd_3d_fp32.cpp
conv/ck/test_ckb_conv_fwd_2d_bf16_scaleadd_relu.cpp
conv/ck/test_ckb_conv_fwd_1d_fp16.cpp
conv/ck/test_ckb_conv_fwd_1d_bf16.cpp
conv/ck/test_ckb_conv_fwd_1d_i8.cpp
conv/ck/test_ckb_conv_fwd_2d_fp8.cpp
conv/ck/test_ckb_conv_fwd_2d_bf16.cpp
conv/ck/test_ckb_conv_fwd_2d_fp16.cpp
conv/ck/test_ckb_conv_fwd_2d_fp32.cpp
conv/ck/test_ckb_conv_fwd_2d_dl_fp16.cpp
conv/ck/test_ckb_conv_fwd_2d_large_tensor_fp16.cpp
conv/ck/test_ckb_conv_fwd_3d_bf16.cpp
conv/ck/test_ckb_conv_fwd_3d_fp16.cpp
conv/ck/test_ckb_conv_fwd_3d_fp32.cpp
conv/ck_tile/test_ckb_conv_fwd_2d_fp16_v3.cpp
conv/ck_tile/test_ckb_conv_bwd_weight_2d_fp16_v3.cpp
conv/ck_tile/test_ckb_conv_bwd_data_2d_fp16_v3.cpp
)

View File

@@ -0,0 +1,52 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_tile_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
namespace {
using namespace ck_tile::builder::test_utils;
// Builds a 2D FP16 grouped convolution backward-data kernel instance via the
// CK Tile builder and verifies the generated type string contains the
// expected components (operation, precision, layouts, tiles, pipeline, flags).
//
// Renamed from FwdConvInstances.Create_..._2D_FP16_NHWGC: that exact gtest
// suite/name pair was also used by the forward and backward-weight tests
// linked into the same test binary, which breaks --gtest_filter and test
// reporting, and the "Fwd" prefix was misleading for BACKWARD_DATA.
TEST(BwdDataConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
{
    // Backward-data signature: 2D, FP16 tensors with FP32 accumulation.
    constexpr ConvSignature BwdDataConvSignature{
        .spatial_dim            = 2,
        .direction              = ConvDirection::BACKWARD_DATA,
        .data_type              = DataType::FP16,
        .accumulation_data_type = DataType::FP32,
        .input                  = {.config = {.layout = TensorLayout::NHWGC}},
        .weight                 = {.config = {.layout = TensorLayout::GKYXC}},
        .output                 = {.config = {.layout = TensorLayout::NHWGK}}};
    constexpr auto BwdDataConvAlgorithm =
        ConvAlgorithm_Tile_GroupedConvolutionKernel{}
            .with_tile_specializations(TileConvSpecialization::DEFAULT)
            .with_tile_thread_block(FwdTileThreadBlock_64x64x64)
            .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
            .with_tile_transfer(FwdTileTransfer_4x4x4)
            .with_tile_optimizations(TileOptimizations{
                .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
    using Builder = ConvBuilder<BwdDataConvSignature, BwdDataConvAlgorithm>;
    // Expected substrings of the generated kernel type string.
    run_ck_tile_test<Builder>({
        "grouped_convolution_backward_data",
        "fp16",
        "NHWGC_GKYXC_NHWGK",
        "64x64x64",
        "2x2",
        "16x16x16",
        // "4x4x4", // TODO: Enable this check
        "Default",
        "Intrawave",
        "CShuffleEpilogue",
        "set",
        "pipeline_AgBgCrCompV3",
        "DoubleSmemBuffer_0",
        "NumWaveGroups_1",
        "MergedGroups_1",
        "SplitImage_0",
        "ExplicitGemm_0",
    });
}
} // namespace

View File

@@ -0,0 +1,52 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_tile_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
namespace {
using namespace ck_tile::builder::test_utils;
// Builds a 2D FP16 grouped convolution backward-weight kernel instance via
// the CK Tile builder and verifies the generated type string contains the
// expected components (operation, precision, layouts, tiles, pipeline, flags).
//
// Renamed from FwdConvInstances.Create_..._2D_FP16_NHWGC: that exact gtest
// suite/name pair was also used by the forward and backward-data tests
// linked into the same test binary, which breaks --gtest_filter and test
// reporting, and the "Fwd" prefix was misleading for BACKWARD_WEIGHT.
TEST(BwdWeightConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
{
    // Backward-weight signature: 2D, FP16 tensors with FP32 accumulation.
    constexpr ConvSignature BwdWeightConvSignature{
        .spatial_dim            = 2,
        .direction              = ConvDirection::BACKWARD_WEIGHT,
        .data_type              = DataType::FP16,
        .accumulation_data_type = DataType::FP32,
        .input                  = {.config = {.layout = TensorLayout::NHWGC}},
        .weight                 = {.config = {.layout = TensorLayout::GKYXC}},
        .output                 = {.config = {.layout = TensorLayout::NHWGK}}};
    constexpr auto BwdWeightConvAlgorithm =
        ConvAlgorithm_Tile_GroupedConvolutionKernel{}
            .with_tile_specializations(TileConvSpecialization::DEFAULT)
            .with_tile_thread_block(FwdTileThreadBlock_64x64x64)
            .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
            .with_tile_transfer(FwdTileTransfer_4x4x4)
            .with_tile_optimizations(TileOptimizations{
                .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
    using Builder = ConvBuilder<BwdWeightConvSignature, BwdWeightConvAlgorithm>;
    // Expected substrings of the generated kernel type string.
    run_ck_tile_test<Builder>({
        "grouped_convolution_backward_weight",
        "fp16",
        "NHWGC_GKYXC_NHWGK",
        "64x64x64",
        "2x2",
        "16x16x16",
        // "4x4x4", // TODO: Enable this check
        "Default",
        "Intrawave",
        "CShuffleEpilogue",
        "set",
        "pipeline_AgBgCrCompV3",
        "DoubleSmemBuffer_0",
        "NumWaveGroups_1",
        "MergedGroups_1",
        "SplitImage_0",
        "ExplicitGemm_0",
    });
}
} // namespace

View File

@@ -0,0 +1,52 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "utils/ckb_conv_tile_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
namespace {
using namespace ck_tile::builder::test_utils;
// Builds a 2D FP16 grouped convolution forward kernel instance via the
// CK Tile builder and verifies the generated type string contains the
// expected components (operation, precision, layouts, tiles, pipeline, flags).
TEST(FwdConvInstances, Create_ConvAlgorithm_Tile_GroupedConvolutionKernel_2D_FP16_NHWGC)
{
    // Forward signature: 2D, FP16 tensors with FP32 accumulation.
    constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
                                             .direction = ConvDirection::FORWARD,
                                             .data_type = DataType::FP16,
                                             .accumulation_data_type = DataType::FP32,
                                             .input = {.config = {.layout = TensorLayout::NHWGC}},
                                             .weight = {.config = {.layout = TensorLayout::GKYXC}},
                                             .output = {.config = {.layout = TensorLayout::NHWGK}}};
    constexpr auto FwdConvAlgorithm =
        ConvAlgorithm_Tile_GroupedConvolutionKernel{}
            .with_tile_specializations(TileConvSpecialization::DEFAULT)
            .with_tile_thread_block(FwdTileThreadBlock_64x64x64)
            .with_tile_block_gemm(TileBlockGemmDesc_16x16_v3_intrawave)
            .with_tile_transfer(FwdTileTransfer_4x4x4)
            .with_tile_optimizations(TileOptimizations{
                .num_groups_to_merge = 1, .split_image = false, .explicit_gemm = false});
    using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
    // Expected substrings of the generated kernel type string.
    run_ck_tile_test<Builder>({
        "grouped_convolution_forward",
        "fp16",
        "NHWGC_GKYXC_NHWGK",
        "64x64x64",
        "2x2",
        "16x16x16",
        // "4x4x4", // TODO: Enable this check
        "Default",
        "Intrawave",
        "CShuffleEpilogue",
        "set",
        "pipeline_AgBgCrCompV3",
        "DoubleSmemBuffer_0",
        "NumWaveGroups_1",
        "MergedGroups_1",
        "SplitImage_0",
        "ExplicitGemm_0",
    });
}
} // namespace

View File

@@ -243,6 +243,73 @@ struct LargeTensorWrapper
ConvAlgorithmSpecialization::LARGE_TENSOR;
};
// Specify thread block dimensions for a GEMM (CK Tile).
struct TileThreadBlock
{
    // Size of the submatrix problem in a thread block.
    MNK<size_t> tile_size;
};
static_assert(ckb::TileThreadBlockDescriptor<TileThreadBlock>);
// Vector load/store widths (scalars per vector) for the A, B and C tensors.
struct TileTransfer
{
    size_t a_scalar_per_vector;
    size_t b_scalar_per_vector;
    size_t c_scalar_per_vector;
};
static_assert(ckb::TileTransferDescriptor<TileTransfer>);
// Block-GEMM configuration: warp layout, warp tile, buffering and pipeline.
struct TileBlockGemm
{
    // Number of warps per each dimension.
    MNK<int> warps;
    // Number of data processed per each dimension for each XDL/WMMA instruction.
    MNK<int> warp_tile;
    // Double LDS buffer.
    bool double_smem_buffer;
    // Waves grouping (Ping-Pong scheduler).
    int num_wave_groups;
    PipelineVersion pipeline_version;
    PipelineScheduler scheduler;
};
static_assert(ckb::TileBlockGemmDescriptor<TileBlockGemm>);
// Convolution-specific optimization switches.
struct TileOptimizations
{
    // Number of convolution groups processed per one workgroup
    int num_groups_to_merge;
    // Split image for large tensors
    bool split_image;
    // Explicit gemm for 1x1, stride=0, pad=0 cases
    bool explicit_gemm;
};
static_assert(ckb::TileOptimizationsDescriptor<TileOptimizations>);
// Mixin wrappers: each holds a single configuration field so that
// ConvAlgorithmTemplate can compose them, and the with_tile_* setters can
// require the corresponding base via std::is_base_of_v.
struct TileConvSpecialization_
{
    TileConvSpecialization specialization;
};
struct TileThreadBlock_
{
    TileThreadBlock thread_block;
};
struct TileTransfer_
{
    TileTransfer transfer;
};
struct TileBlockGemm_
{
    TileBlockGemm block_gemm;
};
struct TileOptimizations_
{
    TileOptimizations optimizations;
};
// Factory
template <typename... Components>
@@ -339,6 +406,51 @@ struct ConvAlgorithmTemplate : Components...
result.transfer = t;
return result;
}
    // Returns a copy of this algorithm with the CK Tile convolution
    // specialization replaced; requires the TileConvSpecialization_ mixin.
    template <typename S>
    constexpr auto with_tile_specializations(const S& s) const
    {
        static_assert(std::is_base_of_v<TileConvSpecialization_, ConvAlgorithmTemplate>);
        auto result = *this;
        result.specialization = s;
        return result;
    }
    // Returns a copy with the thread-block configuration replaced;
    // requires the TileThreadBlock_ mixin.
    template <typename TB>
    constexpr auto with_tile_thread_block(const TB& tb) const
    {
        static_assert(std::is_base_of_v<TileThreadBlock_, ConvAlgorithmTemplate>);
        auto result = *this;
        result.thread_block = tb;
        return result;
    }
    // Returns a copy with the block-GEMM configuration replaced;
    // requires the TileBlockGemm_ mixin.
    template <typename BG>
    constexpr auto with_tile_block_gemm(const BG& bg) const
    {
        static_assert(std::is_base_of_v<TileBlockGemm_, ConvAlgorithmTemplate>);
        auto result = *this;
        result.block_gemm = bg;
        return result;
    }
    // Returns a copy with the transfer (vector width) configuration replaced;
    // requires the TileTransfer_ mixin.
    template <typename T>
    constexpr auto with_tile_transfer(const T& t) const
    {
        static_assert(std::is_base_of_v<TileTransfer_, ConvAlgorithmTemplate>);
        auto result = *this;
        result.transfer = t;
        return result;
    }
    // Returns a copy with the optimization switches replaced;
    // requires the TileOptimizations_ mixin.
    template <typename O>
    constexpr auto with_tile_optimizations(const O& o) const
    {
        static_assert(std::is_base_of_v<TileOptimizations_, ConvAlgorithmTemplate>);
        auto result = *this;
        result.optimizations = o;
        return result;
    }
};
// Algorithm types
@@ -361,4 +473,10 @@ using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
LargeTensorWrapper<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>;
// CK Tile grouped convolution algorithm: composed from the tile mixins above,
// configured via the with_tile_* fluent setters.
using ConvAlgorithm_Tile_GroupedConvolutionKernel = ConvAlgorithmTemplate<TileThreadBlock_,
                                                                          TileBlockGemm_,
                                                                          TileTransfer_,
                                                                          TileConvSpecialization_,
                                                                          TileOptimizations_>;
} // namespace ck_tile::builder::test

View File

@@ -4,7 +4,7 @@
#include <gtest/gtest.h>
#include <type_traits>
#include "ck_tile/builder/factory/helpers/conv_elementwise_op.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_elementwise_op.hpp"
namespace {

View File

@@ -4,7 +4,7 @@
#include <gtest/gtest.h>
#include <type_traits>
#include "ck_tile/builder/factory/helpers/conv_tensor_layout.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_layout.hpp"
#include "impl/conv_signature_types.hpp"
namespace {

View File

@@ -4,7 +4,7 @@
#include <gtest/gtest.h>
#include <type_traits>
#include "ck_tile/builder/factory/helpers/conv_tensor_type.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tensor_type.hpp"
namespace {

View File

@@ -2,7 +2,7 @@
// SPDX-License-Identifier: MIT
#include <gtest/gtest.h>
#include "ck_tile/builder/factory/helpers/conv_thread_block.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_thread_block.hpp"
namespace {

View File

@@ -3,7 +3,7 @@
#include <gtest/gtest.h>
#include "ck_tile/builder/factory/helpers/conv_tuning_params.hpp"
#include "ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp"
namespace {

View File

@@ -28,4 +28,20 @@ constexpr void run_test(const std::vector<std::string>& kernel_instance_componen
}
}
// Common CK Tile test implementation.
//
// Instantiates the builder's kernel type, checks that it produces a
// non-empty type string, and asserts that every expected component appears
// as a substring of that string.
template <typename Builder>
constexpr void run_ck_tile_test(const std::vector<std::string>& kernel_instance_components)
{
    auto instance = typename Builder::Instance{};
    const auto kernel_string = instance.GetTypeString();
    // Log once so failures can be diagnosed from the test output.
    // (Previously the string was printed twice, once with a prefix and once
    // bare — clearly accidental duplicate logging.)
    std::cout << "Generated kernel: " << kernel_string << std::endl;
    EXPECT_GT(kernel_string.size(), 0);
    for(const auto& component : kernel_instance_components)
    {
        EXPECT_THAT(kernel_string, ::testing::HasSubstr(component));
    }
}
} // namespace ck_tile::builder::test_utils

View File

@@ -0,0 +1,85 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "impl/conv_algorithm_types.hpp"
#include "impl/conv_signature_types.hpp"
#include "ck_tile/builder/conv_builder.hpp"
namespace ck_tile::builder::test_utils {
using namespace ck_tile::builder;
using namespace test;
// Per-tensor vector widths (scalars per vector) for A, B and C transfers.
constexpr TileTransfer FwdTileTransfer_1x1x1{
    .a_scalar_per_vector = 1,
    .b_scalar_per_vector = 1,
    .c_scalar_per_vector = 1,
};
constexpr TileTransfer FwdTileTransfer_4x4x4{
    .a_scalar_per_vector = 4,
    .b_scalar_per_vector = 4,
    .c_scalar_per_vector = 4,
};
constexpr TileTransfer FwdTileTransfer_8x8x8{
    .a_scalar_per_vector = 8,
    .b_scalar_per_vector = 8,
    .c_scalar_per_vector = 8,
};
// Thread-block tile sizes (MxNxK) used by the CK Tile builder tests.
constexpr TileThreadBlock FwdTileThreadBlock_256x256x32{.tile_size = {.m = 256, .n = 256, .k = 32}};
constexpr TileThreadBlock FwdTileThreadBlock_256x128x32{.tile_size = {.m = 256, .n = 128, .k = 32}};
constexpr TileThreadBlock FwdTileThreadBlock_128x128x32{.tile_size = {.m = 128, .n = 128, .k = 32}};
constexpr TileThreadBlock FwdTileThreadBlock_128x128x16{.tile_size = {.m = 128, .n = 128, .k = 16}};
constexpr TileThreadBlock FwdTileThreadBlock_64x32x32{.tile_size = {.m = 64, .n = 32, .k = 32}};
constexpr TileThreadBlock FwdTileThreadBlock_64x64x64{.tile_size = {.m = 64, .n = 64, .k = 64}};
// Block-GEMM descriptors: 2x2x1 warps, 16x16x16 warp tile, single LDS buffer,
// one wave group, intrawave scheduling — one per pipeline version V1..V5.
constexpr TileBlockGemm TileBlockGemmDesc_16x16_v1_intrawave = {
    .warps = {.m = 2, .n = 2, .k = 1},
    .warp_tile = {.m = 16, .n = 16, .k = 16},
    .double_smem_buffer = false,
    .num_wave_groups = 1,
    .pipeline_version = PipelineVersion::V1,
    .scheduler = PipelineScheduler::INTRAWAVE};
constexpr TileBlockGemm TileBlockGemmDesc_16x16_v2_intrawave = {
    .warps = {.m = 2, .n = 2, .k = 1},
    .warp_tile = {.m = 16, .n = 16, .k = 16},
    .double_smem_buffer = false,
    .num_wave_groups = 1,
    .pipeline_version = PipelineVersion::V2,
    .scheduler = PipelineScheduler::INTRAWAVE};
constexpr TileBlockGemm TileBlockGemmDesc_16x16_v3_intrawave = {
    .warps = {.m = 2, .n = 2, .k = 1},
    .warp_tile = {.m = 16, .n = 16, .k = 16},
    .double_smem_buffer = false,
    .num_wave_groups = 1,
    .pipeline_version = PipelineVersion::V3,
    .scheduler = PipelineScheduler::INTRAWAVE};
constexpr TileBlockGemm TileBlockGemmDesc_16x16_v4_intrawave = {
    .warps = {.m = 2, .n = 2, .k = 1},
    .warp_tile = {.m = 16, .n = 16, .k = 16},
    .double_smem_buffer = false,
    .num_wave_groups = 1,
    .pipeline_version = PipelineVersion::V4,
    .scheduler = PipelineScheduler::INTRAWAVE};
constexpr TileBlockGemm TileBlockGemmDesc_16x16_v5_intrawave = {
    .warps = {.m = 2, .n = 2, .k = 1},
    .warp_tile = {.m = 16, .n = 16, .k = 16},
    .double_smem_buffer = false,
    .num_wave_groups = 1,
    .pipeline_version = PipelineVersion::V5,
    .scheduler = PipelineScheduler::INTRAWAVE};
} // namespace ck_tile::builder::test_utils

View File

@@ -55,6 +55,11 @@
#ifndef CK_ENABLE_FP32
#define CK_ENABLE_FP32 "ON"
#endif
#ifndef CK_ENABLE_TF32
#if defined(__gfx942__) || defined(__gfx95__)
#define CK_ENABLE_TF32 "ON"
#endif
#endif
#ifndef CK_ENABLE_FP64
#define CK_ENABLE_FP64 "ON"
#endif
@@ -85,6 +90,12 @@
#cmakedefine CK_ENABLE_FP32 @CK_ENABLE_FP32@
#endif
#ifndef CK_ENABLE_TF32
#if defined(__gfx942__) || defined(__gfx95__)
#cmakedefine CK_ENABLE_TF32 @CK_ENABLE_TF32@
#endif
#endif
#ifndef CK_ENABLE_FP64
#cmakedefine CK_ENABLE_FP64 @CK_ENABLE_FP64@
#endif

View File

@@ -176,8 +176,10 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
constexpr index_t WaveNumN = BlockGemmShape::BlockWarps::at(I1{});
return concat('_', "pipeline_AgBgCrCompV3",
concat('x', MPerBlock, NPerBlock, KPerBlock), BlockSize,
concat('x', GetVectorSizeA(), GetVectorSizeB(), GetVectorSizeC()),
concat('x', WaveNumM, WaveNumN),
concat('x', kPadM, kPadN, kPadK));
concat('x', kPadM, kPadN, kPadK),
Problem::GetName());
// clang-format on
}

View File

@@ -301,7 +301,12 @@ struct UniversalGemmPipelineProblem
return concat('_', "gemm_problem",
concat('x', kBlockSize),
concat('x', kPadM, kPadN, kPadK),
Scheduler);
Scheduler,
"NumWaveGroups",
NumWaveGroups,
"DoubleSmemBuffer",
DoubleSmemBuffer
);
// clang-format on
}
};

View File

@@ -560,16 +560,31 @@ struct GroupedConvolutionBackwardDataKernel
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
// clang-format off
return concat('_', "grouped_convolution_backward_data",
gemm_prec_str<InDataType, WeiDataType>(),
InLayout::name,
WeiLayout::name,
OutLayout::name,
"gemm",
GemmPipeline::GetName(),
"epilogue",
EpiloguePipeline::GetName());
EpiloguePipeline::GetName(),
getConvSpecializationString(ConvSpecialization),
"MergedGroups",
NumGroupsToMerge,
"SplitImage",
EnableSplitImage,
"ExplicitGemm",
GroupedConvTraitsType_::ExplicitGemm
);
// clang-format on
}
[[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); }
#ifdef CK_EXPERIMENTAL_BUILDER
CK_TILE_HOST std::string GetInstanceString() const
{

View File

@@ -417,26 +417,31 @@ struct GroupedConvolutionBackwardWeightKernel
[[nodiscard]] CK_TILE_HOST static const std::string GetName()
{
constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
static constexpr bool EnableSplitImage = GroupedConvTraitsType_::EnableSplitImage;
constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
// clang-format off
if (NumGroupsToMerge > 1) {
return concat('_', "grouped_convolution_backward_weight",
gemm_prec_str<InDataType, WeiDataType>(),
"gemm",
GemmPipeline::GetName(),
"epilogue",
EpiloguePipeline::GetName());
} else {
return concat('_', "grouped_convolution_backward_weight",
gemm_prec_str<InDataType, WeiDataType>(),
"gemm",
GemmPipeline::GetName(),
"epilogue",
EpiloguePipeline::GetName(), "merge", NumGroupsToMerge);
}
return concat('_', "grouped_convolution_backward_weight",
gemm_prec_str<InDataType, WeiDataType>(),
InLayout::name,
WeiLayout::name,
OutLayout::name,
"gemm",
GemmPipeline::GetName(),
"epilogue",
EpiloguePipeline::GetName(),
getConvSpecializationString(ConvSpecialization),
"MergedGroups",
NumGroupsToMerge,
"SplitImage",
EnableSplitImage,
"ExplicitGemm",
GroupedConvTraitsType_::ExplicitGemm
);
// clang-format on
}
[[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); }
#ifdef CK_EXPERIMENTAL_BUILDER
CK_TILE_HOST std::string GetInstanceString() const
{

View File

@@ -594,26 +594,28 @@ struct GroupedConvolutionForwardKernel
{
constexpr auto NumGroupsToMerge = GroupedConvTraitsType_::NumGroupsToMerge;
// clang-format off
if (NumGroupsToMerge > 1) {
return concat('_', "grouped_convolution_forward",
gemm_prec_str<InDataType, WeiDataType>(),
"gemm",
GemmPipeline::GetName(),
"epilogue",
EpiloguePipeline::GetName(),
"merge",
NumGroupsToMerge);
} else {
return concat('_', "grouped_convolution_forward",
gemm_prec_str<InDataType, WeiDataType>(),
"gemm",
GemmPipeline::GetName(),
"epilogue",
EpiloguePipeline::GetName());
}
return concat('_', "grouped_convolution_forward",
gemm_prec_str<InDataType, WeiDataType>(),
InLayout::name,
WeiLayout::name,
OutLayout::name,
"gemm",
GemmPipeline::GetName(),
"epilogue",
EpiloguePipeline::GetName(),
getConvSpecializationString(ConvSpecialization),
"MergedGroups",
NumGroupsToMerge,
"SplitImage",
EnableSplitImage,
"ExplicitGemm",
GroupedConvTraitsType_::ExplicitGemm
);
// clang-format on
}
[[nodiscard]] CK_TILE_HOST static const std::string GetTypeString() { return GetName(); }
#ifdef CK_EXPERIMENTAL_BUILDER
CK_TILE_HOST std::string GetInstanceString() const
{

View File

@@ -9,6 +9,13 @@
namespace ck_tile {
// Direction of the grouped convolution computation.
enum class GroupedConvDirection
{
    FORWARD,        // forward convolution
    BACKWARD_DATA,  // gradient w.r.t. the input tensor
    BACKWARD_WEIGHT // gradient w.r.t. the weight tensor
};
/// @brief The Grouped Conv kernel host arguments.
///
/// @par Overview
@@ -113,6 +120,36 @@ struct GroupedConvTraits
using BsLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor;
using CLayoutBwdWeight = ck_tile::tensor_layout::gemm::RowMajor;
    // Selects the GEMM A/B/C layouts for a given convolution direction,
    // mapping onto the AsLayout*/BsLayout*/CLayout* aliases defined above.
    // NOTE(review): static_assert(false, ...) in the never-specialized primary
    // is ill-formed (no diagnostic required) before C++23 and some compilers
    // reject it at definition time; consider a Direction-dependent condition.
    // NOTE(review): explicit specialization at class scope — confirm all
    // supported compilers accept this (pre-C++17 it was disallowed).
    template <GroupedConvDirection Direction>
    struct GemmLayouts
    {
        static_assert(false, "Unsupported direction.");
    };

    // Forward: use the forward-pass GEMM layouts.
    template <>
    struct GemmLayouts<GroupedConvDirection::FORWARD>
    {
        using AsLayout = AsLayoutFwd;
        using BsLayout = BsLayoutFwd;
        using CLayout  = CLayoutFwd;
    };

    // Backward data: use the data-gradient GEMM layouts.
    template <>
    struct GemmLayouts<GroupedConvDirection::BACKWARD_DATA>
    {
        using AsLayout = AsLayoutBwdData;
        using BsLayout = BsLayoutBwdData;
        using CLayout  = CLayoutBwdData;
    };

    // Backward weight: use the weight-gradient GEMM layouts.
    template <>
    struct GemmLayouts<GroupedConvDirection::BACKWARD_WEIGHT>
    {
        using AsLayout = AsLayoutBwdWeight;
        using BsLayout = BsLayoutBwdWeight;
        using CLayout  = CLayoutBwdWeight;
    };
template <ck_tile::index_t NumWaveGroups = 1>
using GroupedConvImplicitGemmTraitsFwd =
TileGemmTraits<true, true, true, AsLayoutFwd, BsLayoutFwd, CLayoutFwd, NumWaveGroups>;

View File

@@ -115,12 +115,12 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
"Error: this operator requires the same compute type");
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<ComputeTypeA, TF32>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_instances(
@@ -130,7 +130,9 @@ struct DeviceOperationInstanceFactory<
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_tf32_optimized_loads_instances(
op_ptrs);
}
else
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ComputeTypeA, F32>)
{
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_instances(
op_ptrs);
@@ -139,8 +141,8 @@ struct DeviceOperationInstanceFactory<
add_device_grouped_conv2d_bwd_data_xdl_nhwgk_gkyxc_nhwgc_f32_optimized_loads_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
@@ -284,12 +286,12 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
"Error: this operator requires the same compute type");
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ComputeTypeA, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_instances(
@@ -299,7 +301,9 @@ struct DeviceOperationInstanceFactory<
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_optimized_loads_instances(
op_ptrs);
}
else if constexpr(is_same_v<ComputeTypeA, TF32>)
#endif
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<ComputeTypeA, TF32>)
{
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
op_ptrs);
@@ -308,8 +312,8 @@ struct DeviceOperationInstanceFactory<
add_device_grouped_conv3d_bwd_data_xdl_ndhwgk_gkzyxc_ndhwgc_f32_tf32_optimized_loads_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_BF16
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&

View File

@@ -44,7 +44,9 @@ void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_in
PassThrough,
PassThrough,
Bilinear>>>& instances);
#endif
#ifdef CK_ENABLE_TF32
void add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
@@ -135,28 +137,30 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
"ComputeTypeA and ComputeTypeB must be the same");
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ComputeTypeA, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_instances(
op_ptrs);
}
else if constexpr(is_same_v<ComputeTypeA, TF32>)
#endif
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<ComputeTypeA, TF32>)
{
add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
{
add_device_grouped_conv3d_bwd_data_xdl_bilinear_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
op_ptrs);

View File

@@ -44,7 +44,9 @@ void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_insta
PassThrough,
PassThrough,
Scale>>>& instances);
#endif
#ifdef CK_ENABLE_TF32
void add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdDataMultipleD<3,
NDHWGK,
@@ -135,28 +137,30 @@ struct DeviceOperationInstanceFactory<
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
else if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
if constexpr(is_same_v<InDataType, F32> && is_same_v<WeiDataType, F32> &&
is_same_v<OutDataType, F32>)
{
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
" only support same compute type");
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ComputeTypeA, F32>)
{
add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_instances(
op_ptrs);
}
else if constexpr(is_same_v<ComputeTypeA, TF32>)
#endif
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<ComputeTypeA, TF32>)
{
add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_f32_tf32_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_BF16
else if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
if constexpr(is_same_v<InDataType, BF16> && is_same_v<WeiDataType, BF16> &&
is_same_v<OutDataType, BF16> && is_same_v<ComputeTypeA, BF16> &&
is_same_v<ComputeTypeB, BF16>)
{
add_device_grouped_conv3d_bwd_data_xdl_scale_ndhwgk_gkzyxc_ndhwgc_bf16_instances(
op_ptrs);

View File

@@ -347,12 +347,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(is_same_v<InLayout, NHWGC> && is_same_v<WeiLayout, GKYXC> &&
is_same_v<OutLayout, NHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
"Error: ComputeTypeA and ComputeTypeB should be the same");
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ComputeTypeA, float>)
{
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
@@ -367,7 +367,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_pad0_pipev5_instances(
op_ptrs);
}
else if constexpr(is_same_v<ComputeTypeA, TF32>)
#endif
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<ComputeTypeA, TF32>)
{
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
op_ptrs);
@@ -380,8 +382,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_bwd_weight_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_pad0_pipev5_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&
@@ -610,12 +612,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
"Error: ComputeTypeA and ComputeTypeB should be the same");
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ComputeTypeA, float>)
{
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
@@ -629,7 +631,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_pad0_pipev5_instances(
op_ptrs);
}
else if constexpr(is_same_v<ComputeTypeA, TF32>)
#endif
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<ComputeTypeA, TF32>)
{
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
@@ -642,8 +646,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_bwd_weight_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_pad0_pipev5_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&

View File

@@ -62,6 +62,9 @@ void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_
PassThrough,
Bilinear,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_TF32
void add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
NDHWGC,
@@ -151,24 +154,26 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
"Error: this operator requires the same compute type");
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<ComputeTypeA, TF32>)
{
add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
}
else
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ComputeTypeA, float>)
{
add_device_grouped_conv3d_bwd_weight_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&

View File

@@ -62,7 +62,9 @@ void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_ins
PassThrough,
Scale,
PassThrough>>>& instances);
#endif
#ifdef CK_ENABLE_TF32
void add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvBwdWeightMultipleD<3,
NDHWGC,
@@ -152,24 +154,26 @@ struct DeviceOperationInstanceFactory<
if constexpr(is_same_v<InLayout, NDHWGC> && is_same_v<WeiLayout, GKZYXC> &&
is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
static_assert(is_same_v<ComputeTypeA, ComputeTypeB>,
"Error: this operator requires the same compute type");
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<ComputeTypeA, TF32>)
{
add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
}
else
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ComputeTypeA, float>)
{
add_device_grouped_conv3d_bwd_weight_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeTypeA, half_t> &&

View File

@@ -198,12 +198,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(NumDimSpatial == 2 && is_same_v<InLayout, NHWGC> &&
is_same_v<WeiLayout, GKYXC> && is_same_v<OutLayout, NHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
static_assert(is_same_v<AComputeType, BComputeType>,
"Error: AComputeType and BComputeType should be the same!");
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<AComputeType, float>)
{
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_instances(op_ptrs);
@@ -219,7 +219,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
op_ptrs);
}
else if constexpr(is_same_v<AComputeType, TF32>)
#endif
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<AComputeType, TF32>)
{
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(op_ptrs);
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_16x16_instances(
@@ -235,8 +237,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_fwd_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<AComputeType, half_t> &&
@@ -451,10 +453,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<AComputeType, BComputeType> && is_same_v<BComputeType, TF32>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
@@ -472,7 +474,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_fwd_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
}
else
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<AComputeType, BComputeType> &&
is_same_v<BComputeType, float>)
{
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(op_ptrs);
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_16x16_instances(
@@ -488,8 +493,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_fwd_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_FP8
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&

View File

@@ -129,12 +129,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
static_assert(is_same_v<AComputeType, BComputeType>,
"A and B compute types should be the same");
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<AComputeType, float>)
{
@@ -153,7 +153,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
op_ptrs);
}
else if constexpr(is_same_v<AComputeType, TF32>)
#endif
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<AComputeType, TF32>)
{
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
op_ptrs);
@@ -170,8 +172,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_fwd_bias_bn_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_mem_inter_instances(
op_ptrs);
}
}
#endif
}
}
// layout NDHWGC/GKZYXC/NDHWGK
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
@@ -229,12 +231,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
static_assert(is_same_v<AComputeType, BComputeType>,
"A and B compute types should be the same");
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<AComputeType, float>)
{
@@ -253,7 +255,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
op_ptrs);
}
else if constexpr(is_same_v<AComputeType, TF32>)
#endif
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<AComputeType, TF32>)
{
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
@@ -270,8 +274,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_fwd_bias_bn_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_mem_inter_instances(
op_ptrs);
}
}
#endif
}
}
#endif // CK_USE_XDL

View File

@@ -129,12 +129,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
static_assert(is_same_v<AComputeType, BComputeType>,
"Error: AComputeType and BComputeType should be the same");
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<AComputeType, TF32>)
{
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
@@ -152,7 +152,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_fwd_bias_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
op_ptrs);
}
else
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<AComputeType, float>)
{
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs);
@@ -169,9 +171,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_fwd_bias_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
op_ptrs);
}
}
#endif
}
}
// layout NDHWGC/GKZYXC/NDHWGK
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
@@ -221,12 +222,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
static_assert(is_same_v<AComputeType, BComputeType>,
"Error: AComputeType and BComputeType should be the same");
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<AComputeType, TF32>)
{
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
@@ -244,7 +245,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_fwd_bias_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
}
else
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<AComputeType, float>)
{
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
@@ -261,9 +264,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_fwd_bias_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
op_ptrs);
}
}
#endif
}
}
#endif // CK_USE_XDL

View File

@@ -68,7 +68,9 @@ void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instanc
PassThrough,
PassThrough,
Bilinear>>>& instances);
#endif
#ifdef CK_ENABLE_TF32
void add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
@@ -149,22 +151,24 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK> &&
DLayouts::Size() == 1 && is_same_v<tuple_element_t<0, DLayouts>, NDHWGK>)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<ComputeType, TF32>)
{
add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
}
else
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ComputeType, float>)
{
add_device_grouped_conv3d_fwd_xdl_bilinear_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)

View File

@@ -127,12 +127,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
static_assert(is_same_v<AComputeType, BComputeType>,
"Error: AComputeType and BComputeType should be the same");
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<AComputeType, TF32>)
{
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
@@ -150,7 +150,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_fwd_clamp_xdl_large_tensor_nhwgc_gkyxc_nhwgk_f32_tf32_instances(
op_ptrs);
}
else
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<AComputeType, float>)
{
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_instances(
op_ptrs);
@@ -167,9 +169,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv2d_fwd_clamp_xdl_nhwgc_gkyxc_nhwgk_f32_mem_inter_instances(
op_ptrs);
}
}
#endif
}
}
// layout NDHWGC/GKZYXC/NDHWGK
if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWGC> &&
@@ -218,12 +219,12 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
op_ptrs);
}
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
static_assert(is_same_v<AComputeType, BComputeType>,
"Error: AComputeType and BComputeType should be the same");
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<AComputeType, TF32>)
{
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
@@ -241,7 +242,9 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_fwd_clamp_xdl_large_tensor_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
}
else
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<AComputeType, float>)
{
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
@@ -258,8 +261,8 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
add_device_grouped_conv3d_fwd_clamp_xdl_ndhwgc_gkzyxc_ndhwgk_f32_mem_inter_instances(
op_ptrs);
}
}
#endif
}
}
#endif // CK_USE_XDL

View File

@@ -68,7 +68,9 @@ void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances(
PassThrough,
PassThrough,
Scale>>>& instances);
#endif
#ifdef CK_ENABLE_TF32
void add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
NDHWGC,
@@ -149,22 +151,24 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
is_same_v<WeiLayout, GKZYXC> && is_same_v<OutLayout, NDHWGK> &&
DLayouts::Size() == 0)
{
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<InDataType, float> && is_same_v<WeiDataType, float> &&
is_same_v<OutDataType, float>)
{
#ifdef CK_ENABLE_TF32
if constexpr(is_same_v<ComputeType, TF32>)
{
add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_tf32_instances(
op_ptrs);
}
else
#endif
#ifdef CK_ENABLE_FP32
if constexpr(is_same_v<ComputeType, float>)
{
add_device_grouped_conv3d_fwd_xdl_scale_ndhwgc_gkzyxc_ndhwgk_f32_instances(
op_ptrs);
}
}
#endif
}
#ifdef CK_ENABLE_FP16
if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
is_same_v<OutDataType, half_t> && is_same_v<ComputeType, half_t>)

View File

@@ -13,6 +13,8 @@ function(add_instance_library INSTANCE_NAME)
set(type1 "_f16")
elseif(type MATCHES "fp32")
set(type1 "_f32")
elseif(type MATCHES "tf32")
set(type1 "_tf32")
elseif(type MATCHES "fp8")
set(type1 "_f8")
elseif(type MATCHES "bf16")
@@ -27,8 +29,8 @@ function(add_instance_library INSTANCE_NAME)
#if filename matches any selected type, exit type loop and do no exclude the file from the list
set(test 0)
break()
elseif((source_name MATCHES "fp8" OR source_name MATCHES "fp32" OR source_name MATCHES "fp64" OR source_name MATCHES "bf16" OR source_name MATCHES "int8" OR source_name MATCHES "fp16" OR
source_name MATCHES "_f8" OR source_name MATCHES "_f32" OR source_name MATCHES "_f64" OR source_name MATCHES "_i8" OR source_name MATCHES "_f16" OR source_name MATCHES "_b16") AND
elseif((source_name MATCHES "fp8" OR source_name MATCHES "fp32" OR source_name MATCHES "tf32" OR source_name MATCHES "fp64" OR source_name MATCHES "bf16" OR source_name MATCHES "int8" OR source_name MATCHES "fp16" OR
source_name MATCHES "_f8" OR source_name MATCHES "_f32" OR source_name MATCHES "_tf32" OR source_name MATCHES "_f64" OR source_name MATCHES "_i8" OR source_name MATCHES "_f16" OR source_name MATCHES "_b16") AND
NOT (source_name MATCHES type OR source_name MATCHES type1))
#if filename contains a type which doesn't match any selected type, mark it for removal
set(test 1)
@@ -102,9 +104,11 @@ function(add_instance_library INSTANCE_NAME)
list(REMOVE_ITEM ARGN "${source}")
endif()
# Only build tf32 instances for gfx942 & gfx950
if(NOT (INST_TARGETS MATCHES "gfx942|gfx950") AND source_name MATCHES "_tf32_")
message(DEBUG "removing tf32 instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
if(source_name MATCHES "_tf32_")
if(NOT ((INST_TARGETS MATCHES "gfx942|gfx950") AND CK_ENABLE_TF32))
message(DEBUG "removing tf32 instance ${source} ")
list(REMOVE_ITEM ARGN "${source}")
endif()
endif()
endforeach()
@@ -223,6 +227,10 @@ FOREACH(subdir_path ${dir_list})
message(DEBUG "fp32 instance found!")
set(add_inst 1)
endif()
if(("${cmake_instance}" MATCHES "_tf32" OR "${cmake_instance}" MATCHES "_tf32") AND DTYPES MATCHES "tf32")
message(DEBUG "tf32 instance found!")
set(add_inst 1)
endif()
if(("${cmake_instance}" MATCHES "_fp64" OR "${cmake_instance}" MATCHES "_f64") AND DTYPES MATCHES "fp64")
message(DEBUG "fp64 instance found!")
set(add_inst 1)
@@ -237,6 +245,7 @@ FOREACH(subdir_path ${dir_list})
"${cmake_instance}" MATCHES "_f16" OR
"${cmake_instance}" MATCHES "_fp32" OR
"${cmake_instance}" MATCHES "_f32" OR
"${cmake_instance}" MATCHES "_tf32" OR
"${cmake_instance}" MATCHES "_fp64" OR
"${cmake_instance}" MATCHES "_f64" OR
"${cmake_instance}" MATCHES "_bf16" OR
@@ -330,7 +339,7 @@ FOREACH(subdir_path ${dir_list})
list(APPEND CK_DEVICE_OTHER_INSTANCES $<TARGET_OBJECTS:device_${target_dir}_instance>)
endif()
message(DEBUG "add_instance_directory ${subdir_path}")
endif()
endif()
else()
message(DEBUG "skip_instance_directory ${subdir_path}")
endif()

View File

@@ -84,9 +84,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
using F32 = float;
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
#if defined(__gfx942__)
using TF32 = ck::tf32_t;
#endif
using namespace ck::tensor_layout::convolution;
@@ -143,9 +141,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I2, GNHWK{}, GKYXC{}, GNHWC{}, F32{}, F32{}, F32{}, TF32{});
#endif
}
}
else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -164,9 +160,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I2, NHWGK{}, GKYXC{}, NHWGC{}, F32{}, F32{}, F32{}, TF32{});
#endif
}
}
else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -185,9 +179,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I2, NGKHW{}, GKYXC{}, NGCHW{}, F32{}, F32{}, F32{}, TF32{});
#endif
}
}
else if(layout == ConvLayout::NGCHW_GKCYX_NGKHW)
@@ -206,9 +198,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I2, NGKHW{}, GKCYX{}, NGCHW{}, F32{}, F32{}, F32{}, TF32{});
#endif
}
}
}
@@ -230,9 +220,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I3, GNDHWK{}, GKZYXC{}, GNDHWC{}, F32{}, F32{}, F32{}, TF32{});
#endif
}
}
else if(layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -251,9 +239,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I3, NDHWGK{}, GKZYXC{}, NDHWGC{}, F32{}, F32{}, F32{}, TF32{});
#endif
}
}
else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -272,9 +258,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I3, NGKDHW{}, GKZYXC{}, NGCDHW{}, F32{}, F32{}, F32{}, TF32{});
#endif
}
}
else if(layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -293,9 +277,7 @@ int profile_grouped_conv_bwd_data(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I3, NGKDHW{}, GKCZYX{}, NGCDHW{}, F32{}, F32{}, F32{}, TF32{});
#endif
}
}
}

View File

@@ -99,9 +99,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
using BF16 = ck::bhalf_t;
using F8 = ck::f8_t;
using BF8 = ck::bf8_t;
#if defined(__gfx942__)
using TF32 = ck::tf32_t;
#endif
using namespace ck::tensor_layout::convolution;
@@ -162,9 +160,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
@@ -184,9 +180,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -210,9 +204,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -243,9 +235,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
@@ -270,9 +260,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -306,9 +294,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 3 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -340,9 +326,7 @@ int profile_grouped_conv_bwd_weight(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}

View File

@@ -105,9 +105,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
using INT8 = int8_t;
using F8 = ck::f8_t;
using BF8 = ck::bf8_t;
#if defined(__gfx942__) || defined(__gfx950__)
using TF32 = ck::tf32_t;
#endif
//
using GNWC = ck::tensor_layout::convolution::GNWC;
@@ -228,9 +226,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I1, GNWC{}, GKXC{}, GNWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
@@ -253,9 +249,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I2, GNHWC{}, GKYXC{}, GNHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 3 && layout == ConvLayout::GNHWC_GKYXC_GNHWK)
@@ -280,9 +274,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I3, GNDHWC{}, GKZYXC{}, GNDHWK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
// NHWGC_GKYXC_NHWGK
@@ -306,9 +298,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I1, NWGC{}, GKXC{}, NWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -331,9 +321,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKYXC_NGKHW)
@@ -352,9 +340,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKCYX_NGKHW)
@@ -373,9 +359,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -416,9 +400,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
// NGCDHW_GKCZYX_NGKDHW
@@ -439,9 +421,7 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__) || defined(__gfx950__)
return profile(I3, NGCDHW{}, GKCZYX{}, NGKDHW{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}

View File

@@ -105,9 +105,7 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[])
using F32 = float;
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
#if defined(__gfx942__)
using TF32 = ck::tf32_t;
#endif
using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
using NDHWGC = ck::tensor_layout::convolution::NDHWGC;
@@ -172,9 +170,7 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -194,9 +190,7 @@ int grouped_conv_fwd_bias_clamp(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}

View File

@@ -105,9 +105,7 @@ int grouped_conv_fwd_clamp(int argc, char* argv[])
using F32 = float;
using BF16 = ck::bhalf_t;
using F16 = ck::half_t;
#if defined(__gfx942__)
using TF32 = ck::tf32_t;
#endif
using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
using NDHWGC = ck::tensor_layout::convolution::NDHWGC;
@@ -175,9 +173,7 @@ int grouped_conv_fwd_clamp(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I2, NHWGC{}, GKYXC{}, NHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
@@ -197,9 +193,7 @@ int grouped_conv_fwd_clamp(int argc, char* argv[])
}
else if(data_type == ConvDataType::F32_F32_F32_TF32)
{
#if defined(__gfx942__)
return profile(I3, NDHWGC{}, GKZYXC{}, NDHWGK{}, F32{}, F32{}, F32{}, TF32{}, TF32{});
#endif
}
}

View File

@@ -65,6 +65,9 @@ function(add_test_executable TEST_NAME)
if((source_name MATCHES "_fp32|_f32") AND NOT "fp32" IN_LIST DTYPES)
set(test 1)
endif()
if((source_name MATCHES "_tf32|_tf32") AND NOT "tf32" IN_LIST DTYPES)
set(test 1)
endif()
if((source_name MATCHES "_fp64|_f64") AND NOT "fp64" IN_LIST DTYPES)
set(test 1)
endif()
@@ -156,6 +159,9 @@ function(add_gtest_executable TEST_NAME)
if((source_name MATCHES "_fp32|_f32") AND NOT "fp32" IN_LIST DTYPES)
set(test 1)
endif()
if((source_name MATCHES "_tf32|_tf32") AND NOT "tf32" IN_LIST DTYPES)
set(test 1)
endif()
if((source_name MATCHES "_fp64|_f64") AND NOT "fp64" IN_LIST DTYPES)
set(test 1)
endif()