mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
Separate types from concepts.
This commit is contained in:
@@ -1,38 +1,37 @@
|
||||
|
||||
#include "ck_tile/builder/conv_builder.hpp"
|
||||
#include "../utils/types.hpp"
|
||||
#include "ck_tile/builder/conv_signature_types.hpp"
|
||||
|
||||
int main()
|
||||
{
|
||||
namespace ckb = ck_tile::builder;
|
||||
namespace ckb_examples = ck_tile::builder::examples;
|
||||
using namespace ck_tile::builder;
|
||||
|
||||
constexpr ckb_examples::ConvSignature FwdConvSignature
|
||||
constexpr ConvSignature FwdConvSignature
|
||||
{
|
||||
.spatial_dim = 2,
|
||||
.direction = ckb::ConvDirection::FORWARD,
|
||||
.layout = ckb::GroupConvLayout::CHANNELS_LAST,
|
||||
.data_type = ckb::DataType::BF16,
|
||||
.direction = ConvDirection::FORWARD,
|
||||
.layout = GroupConvLayout::CHANNELS_LAST,
|
||||
.data_type = DataType::BF16
|
||||
};
|
||||
static_assert(ckb::ValidConvSignature<FwdConvSignature>);
|
||||
static_assert(ValidConvSignature<FwdConvSignature>);
|
||||
|
||||
// To get valid configuration parameters, refer to "device_grouped_conv_fwd_xdl_comp_instance.hpp".
|
||||
// This file contains the current instances of the kernel DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3.
|
||||
// Currently the build has this kernel hard-coded.
|
||||
// In the future, we may need builders per kernel type since they typically have slightly different parameters.
|
||||
|
||||
constexpr ckb::ThreadBlock FwdThreadBlock
|
||||
constexpr ThreadBlock FwdThreadBlock
|
||||
{
|
||||
.block_size = 256,
|
||||
.submatrix = {.m = 256, .n = 256, .k = 32} // Tile sizes
|
||||
};
|
||||
|
||||
constexpr ckb::ConvTuningParams FwdTuningParams
|
||||
constexpr ConvTuningParams FwdTuningParams
|
||||
{
|
||||
.ak1 = 8, .bk1 = 8, .m_per_xdl=32, .n_per_xdl = 32, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4
|
||||
};
|
||||
|
||||
constexpr ckb_examples::BlockTransfer FwdBlockTransfer
|
||||
constexpr BlockTransfer FwdBlockTransfer
|
||||
{
|
||||
.thread_cluster_dims_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_dims_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
@@ -50,15 +49,15 @@ int main()
|
||||
.b_source_access_order = {1, 0, 2}
|
||||
};
|
||||
|
||||
constexpr ckb_examples::ConvAlgorithm FwdConvAlgorithm
|
||||
constexpr ConvAlgorithm FwdConvAlgorithm
|
||||
{
|
||||
.thread_block = FwdThreadBlock,
|
||||
.tuning_params = FwdTuningParams,
|
||||
.block_transfer = FwdBlockTransfer,
|
||||
.pipeline_version = ckb::BlockGemmPipelineVersion::V4,
|
||||
.pipeline_version = BlockGemmPipelineVersion::V4,
|
||||
};
|
||||
|
||||
using Builder = ckb::ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
const auto kernel_string = Builder::Instance::TypeString();
|
||||
std::cout << "Generated kernel: " << kernel_string << std::endl;
|
||||
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/builder/conv_algorithm.hpp"
|
||||
#include "ck_tile/builder/conv_signature.hpp"
|
||||
|
||||
namespace ck_tile::builder::examples
|
||||
{
|
||||
namespace ckb = ck_tile::builder;
|
||||
|
||||
struct ConvSignature {
|
||||
int spatial_dim;
|
||||
ckb::ConvDirection direction;
|
||||
ckb::GroupConvLayout layout;
|
||||
ckb::DataType data_type;
|
||||
};
|
||||
static_assert(ckb::ConvSignatureDescriptor<ConvSignature>);
|
||||
|
||||
struct BlockTransfer
|
||||
{
|
||||
ckb::BlockATransferLengths thread_cluster_dims_a;
|
||||
ckb::BlockBTransferLengths thread_cluster_dims_b;
|
||||
ckb::BlockCTransferLengths thread_cluster_dims_c;
|
||||
ckb::VectorTransferAB vector_transfer_a;
|
||||
ckb::VectorTransferAB vector_transfer_b;
|
||||
ckb::VectorTransferC vector_transfer_c;
|
||||
ckb::ThreadClusterAccessOrder a_thread_cluster_access_order;
|
||||
ckb::ThreadClusterAccessOrder b_thread_cluster_access_order;
|
||||
ckb::SourceAccessOrder a_source_access_order;
|
||||
ckb::SourceAccessOrder b_source_access_order;
|
||||
};
|
||||
|
||||
struct ConvAlgorithm
|
||||
{
|
||||
ckb::ThreadBlock thread_block;
|
||||
ckb::ConvTuningParams tuning_params;
|
||||
BlockTransfer block_transfer;
|
||||
ckb::BlockGemmPipelineVersion pipeline_version;
|
||||
};
|
||||
static_assert(ckb::ConvAlgorithmDescriptor<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesThreadBlock<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesConvTuning<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesGemmPipelineVersion<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesBlockATransfer<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesBlockBTransfer<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesBlockCTransfer<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesBlockAVectorTransfer<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesBlockBVectorTransfer<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesBlockCVectorTransfer<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesAThreadClusterAccessOrder<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesBThreadClusterAccessOrder<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesASourceAccessOrder<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesBSourceAccessOrder<ConvAlgorithm>);
|
||||
} // namespace ck_tile::builder::examples
|
||||
@@ -67,4 +67,41 @@ struct UnsupportedEnumValue
|
||||
{
|
||||
};
|
||||
|
||||
// Helper functions to convert enums to strings
|
||||
constexpr std::string_view ConvDirectionToString(ConvDirection dir)
|
||||
{
|
||||
switch(dir)
|
||||
{
|
||||
case ConvDirection::FORWARD: return "Forward";
|
||||
case ConvDirection::BACKWARD_DATA: return "Backward Data";
|
||||
case ConvDirection::BACKWARD_WEIGHT: return "Backward Weight";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
constexpr std::string_view DataTypeToString(DataType dt)
|
||||
{
|
||||
switch(dt)
|
||||
{
|
||||
case DataType::FP16: return "FP16";
|
||||
case DataType::FP32: return "FP32";
|
||||
case DataType::FP64: return "FP64";
|
||||
case DataType::BF16: return "BF16";
|
||||
case DataType::S16: return "S16";
|
||||
case DataType::S8: return "S8";
|
||||
case DataType::S4: return "S4";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
constexpr std::string_view LayoutToString(GroupConvLayout layout)
|
||||
{
|
||||
switch(layout)
|
||||
{
|
||||
case GroupConvLayout::CHANNELS_FIRST: return "Channels-first (NCHW)";
|
||||
case GroupConvLayout::CHANNELS_LAST: return "Channels-last (NHWC)";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -3,23 +3,10 @@
|
||||
#include <type_traits>
|
||||
#include <concepts>
|
||||
#include <array>
|
||||
#include "types.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// TODO: VP (Oct 3, 2025) - Separate the concepts and structs into separate files.
|
||||
// Concepts the define interface and structs are PODs that implement the concepts.
|
||||
// The interface is really just the concepts. Clients can define their own structs
|
||||
// as long as they satisfy the concepts.
|
||||
|
||||
// Convenience struct for a tuple of m, n, and k values.
|
||||
template <typename T>
|
||||
struct MNK
|
||||
{
|
||||
T m{};
|
||||
T n{};
|
||||
T k{};
|
||||
};
|
||||
|
||||
// Concept for thread block dimensions for a GEMM problem.
|
||||
template <typename T>
|
||||
concept ThreadBlockDescriptor = requires(T t) {
|
||||
@@ -29,16 +16,6 @@ concept ThreadBlockDescriptor = requires(T t) {
|
||||
{ t.submatrix.k } -> std::convertible_to<int>;
|
||||
};
|
||||
|
||||
// Specifiy thread block dimensions for a GEMM.
|
||||
struct ThreadBlock
|
||||
{
|
||||
// Thread block size.
|
||||
int block_size;
|
||||
// Size of the submatrix problem in a thread block.
|
||||
MNK<int> submatrix;
|
||||
};
|
||||
static_assert(ThreadBlockDescriptor<ThreadBlock>);
|
||||
|
||||
// Concept to check if struct specifies thread block info.
|
||||
template <typename T>
|
||||
concept SpecifiesThreadBlock = requires {
|
||||
@@ -56,19 +33,6 @@ concept ConvTuningDescriptor = requires(T t) {
|
||||
{ t.n_xdl_per_wave } -> std::convertible_to<int>;
|
||||
};
|
||||
|
||||
// Describe some convolution tuning parameters.
|
||||
struct ConvTuningParams
|
||||
{
|
||||
// NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!!
|
||||
int ak1 = 0;
|
||||
int bk1 = 0;
|
||||
int m_per_xdl = 0;
|
||||
int n_per_xdl = 0;
|
||||
int m_xdl_per_wave = 0;
|
||||
int n_xdl_per_wave = 0;
|
||||
};
|
||||
static_assert(ConvTuningDescriptor<ConvTuningParams>);
|
||||
|
||||
// Concept to check if a struct specifies convolution tuning info.
|
||||
template <typename T>
|
||||
concept SpecifiesConvTuning = requires {
|
||||
@@ -83,15 +47,6 @@ concept BlockATransferDescriptor = requires(T t) {
|
||||
{ t.k1 } -> std::convertible_to<int>;
|
||||
};
|
||||
|
||||
// Describe A block transfer thread cluster lengths.
|
||||
struct BlockATransferLengths
|
||||
{
|
||||
int k0;
|
||||
int m;
|
||||
int k1;
|
||||
};
|
||||
static_assert(BlockATransferDescriptor<BlockATransferLengths>);
|
||||
|
||||
// Concept for B block transfer thread cluster lengths.
|
||||
template <typename T>
|
||||
concept BlockBTransferDescriptor = requires(T t) {
|
||||
@@ -100,43 +55,18 @@ concept BlockBTransferDescriptor = requires(T t) {
|
||||
{ t.k1 } -> std::convertible_to<int>;
|
||||
};
|
||||
|
||||
// Describe B block transfer thread cluster lengths.
|
||||
struct BlockBTransferLengths
|
||||
{
|
||||
int k0;
|
||||
int n;
|
||||
int k1;
|
||||
};
|
||||
static_assert(BlockBTransferDescriptor<BlockBTransferLengths>);
|
||||
|
||||
// Concept for the thread cluster access order
|
||||
template <typename T>
|
||||
concept ThreadClusterAccessOrderDescriptor = requires(T t) {
|
||||
{ t.order } -> std::convertible_to<std::array<int, 3>>;
|
||||
};
|
||||
|
||||
// Describe the thread cluster access order.
|
||||
struct ThreadClusterAccessOrder
|
||||
{
|
||||
// Order of the cluster dimensions. Must be a permutation of {0, 1, 2}.
|
||||
std::array<int, 3> order;
|
||||
};
|
||||
static_assert(ThreadClusterAccessOrderDescriptor<ThreadClusterAccessOrder>);
|
||||
|
||||
// Concept to describe source access order
|
||||
template <typename T>
|
||||
concept SourceAccessOrderDescriptor = requires(T t) {
|
||||
{ t.order } -> std::convertible_to<std::array<int, 3>>;
|
||||
};
|
||||
|
||||
// Describe the source access order.
|
||||
struct SourceAccessOrder
|
||||
{
|
||||
// Order of the source dimensions. Must be a permutation of {0, 1, 2}.
|
||||
std::array<int, 3> order;
|
||||
};
|
||||
static_assert(SourceAccessOrderDescriptor<SourceAccessOrder>);
|
||||
|
||||
// Concept for C block transfer thread cluster lengths.
|
||||
template <typename T>
|
||||
concept BlockCTransferDescriptor = requires(T t) {
|
||||
@@ -146,16 +76,6 @@ concept BlockCTransferDescriptor = requires(T t) {
|
||||
{ t.n_wave_per_xdl } -> std::convertible_to<int>;
|
||||
};
|
||||
|
||||
// Describe C block transfer thread cluster lengths.
|
||||
struct BlockCTransferLengths
|
||||
{
|
||||
int m_block;
|
||||
int m_wave_per_xdl;
|
||||
int n_block;
|
||||
int n_wave_per_xdl;
|
||||
};
|
||||
static_assert(BlockCTransferDescriptor<BlockCTransferLengths>);
|
||||
|
||||
// Concept for vector transfer details for A and B tensors
|
||||
template <typename T>
|
||||
concept VectorTransferDescriptorAB = requires(T t) {
|
||||
@@ -165,15 +85,6 @@ concept VectorTransferDescriptorAB = requires(T t) {
|
||||
{ t.add_extra } -> std::convertible_to<bool>;
|
||||
};
|
||||
|
||||
struct VectorTransferAB
|
||||
{
|
||||
size_t src_vector_dim;
|
||||
size_t src_scaler_per_vector;
|
||||
size_t dest_scaler_per_vector_k1;
|
||||
bool add_extra;
|
||||
};
|
||||
static_assert(VectorTransferDescriptorAB<VectorTransferAB>);
|
||||
|
||||
// Concept for the C tensor vectors transfer details.
|
||||
template <typename T>
|
||||
concept VectorTransferDescriptorC = requires(T t) {
|
||||
@@ -182,14 +93,6 @@ concept VectorTransferDescriptorC = requires(T t) {
|
||||
{ t.scaler_per_vector } -> std::convertible_to<size_t>;
|
||||
};
|
||||
|
||||
struct VectorTransferC
|
||||
{
|
||||
size_t m_xdl_per_wave_per_shuffle;
|
||||
size_t n_xdl_per_wave_per_shuffle;
|
||||
size_t scaler_per_vector;
|
||||
};
|
||||
static_assert(VectorTransferDescriptorC<VectorTransferC>);
|
||||
|
||||
// Concept to check if a struct specifies A Block tranfer info.
|
||||
template <typename T>
|
||||
concept SpecifiesBlockATransfer = requires(T t) {
|
||||
@@ -250,15 +153,6 @@ concept SpecifiesBSourceAccessOrder = requires(T t) {
|
||||
{ T::block_transfer.b_source_access_order } -> SourceAccessOrderDescriptor;
|
||||
};
|
||||
|
||||
// Enums for the current block GEMM pipeline versions.
|
||||
enum class BlockGemmPipelineVersion
|
||||
{
|
||||
V1,
|
||||
V3,
|
||||
V4,
|
||||
V5
|
||||
};
|
||||
|
||||
// Concept to check if struct specifies block_gemm_pipeline_version.
|
||||
template <typename T>
|
||||
concept SpecifiesGemmPipelineVersion = requires {
|
||||
|
||||
@@ -0,0 +1,135 @@
|
||||
#pragma once
|
||||
|
||||
#include "conv_algorithm_concepts.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
// Convenience struct for a tuple of m, n, and k values.
|
||||
template <typename T>
|
||||
struct MNK
|
||||
{
|
||||
T m{};
|
||||
T n{};
|
||||
T k{};
|
||||
};
|
||||
|
||||
// Specifiy thread block dimensions for a GEMM.
|
||||
struct ThreadBlock
|
||||
{
|
||||
// Thread block size.
|
||||
int block_size;
|
||||
// Size of the submatrix problem in a thread block.
|
||||
MNK<int> submatrix;
|
||||
};
|
||||
static_assert(ThreadBlockDescriptor<ThreadBlock>);
|
||||
|
||||
// Describe some convolution tuning parameters.
|
||||
struct ConvTuningParams
|
||||
{
|
||||
// NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!!
|
||||
int ak1 = 0;
|
||||
int bk1 = 0;
|
||||
int m_per_xdl = 0;
|
||||
int n_per_xdl = 0;
|
||||
int m_xdl_per_wave = 0;
|
||||
int n_xdl_per_wave = 0;
|
||||
};
|
||||
static_assert(ConvTuningDescriptor<ConvTuningParams>);
|
||||
|
||||
// Describe A block transfer thread cluster lengths.
|
||||
struct BlockATransferLengths
|
||||
{
|
||||
int k0;
|
||||
int m;
|
||||
int k1;
|
||||
};
|
||||
static_assert(BlockATransferDescriptor<BlockATransferLengths>);
|
||||
|
||||
// Describe B block transfer thread cluster lengths.
|
||||
struct BlockBTransferLengths
|
||||
{
|
||||
int k0;
|
||||
int n;
|
||||
int k1;
|
||||
};
|
||||
static_assert(BlockBTransferDescriptor<BlockBTransferLengths>);
|
||||
|
||||
// Describe the thread cluster access order.
|
||||
struct ThreadClusterAccessOrder
|
||||
{
|
||||
// Order of the cluster dimensions. Must be a permutation of {0, 1, 2}.
|
||||
std::array<int, 3> order;
|
||||
};
|
||||
static_assert(ThreadClusterAccessOrderDescriptor<ThreadClusterAccessOrder>);
|
||||
|
||||
// Describe the source access order.
|
||||
struct SourceAccessOrder
|
||||
{
|
||||
// Order of the source dimensions. Must be a permutation of {0, 1, 2}.
|
||||
std::array<int, 3> order;
|
||||
};
|
||||
static_assert(SourceAccessOrderDescriptor<SourceAccessOrder>);
|
||||
|
||||
// Describe C block transfer thread cluster lengths.
|
||||
struct BlockCTransferLengths
|
||||
{
|
||||
int m_block;
|
||||
int m_wave_per_xdl;
|
||||
int n_block;
|
||||
int n_wave_per_xdl;
|
||||
};
|
||||
static_assert(BlockCTransferDescriptor<BlockCTransferLengths>);
|
||||
|
||||
struct VectorTransferAB
|
||||
{
|
||||
size_t src_vector_dim;
|
||||
size_t src_scaler_per_vector;
|
||||
size_t dest_scaler_per_vector_k1;
|
||||
bool add_extra;
|
||||
};
|
||||
static_assert(VectorTransferDescriptorAB<VectorTransferAB>);
|
||||
|
||||
struct VectorTransferC
|
||||
{
|
||||
size_t m_xdl_per_wave_per_shuffle;
|
||||
size_t n_xdl_per_wave_per_shuffle;
|
||||
size_t scaler_per_vector;
|
||||
};
|
||||
static_assert(VectorTransferDescriptorC<VectorTransferC>);
|
||||
|
||||
struct BlockTransfer
|
||||
{
|
||||
BlockATransferLengths thread_cluster_dims_a;
|
||||
BlockBTransferLengths thread_cluster_dims_b;
|
||||
BlockCTransferLengths thread_cluster_dims_c;
|
||||
VectorTransferAB vector_transfer_a;
|
||||
VectorTransferAB vector_transfer_b;
|
||||
VectorTransferC vector_transfer_c;
|
||||
ThreadClusterAccessOrder a_thread_cluster_access_order;
|
||||
ThreadClusterAccessOrder b_thread_cluster_access_order;
|
||||
SourceAccessOrder a_source_access_order;
|
||||
SourceAccessOrder b_source_access_order;
|
||||
};
|
||||
|
||||
struct ConvAlgorithm
|
||||
{
|
||||
ThreadBlock thread_block;
|
||||
ConvTuningParams tuning_params;
|
||||
BlockTransfer block_transfer;
|
||||
BlockGemmPipelineVersion pipeline_version;
|
||||
};
|
||||
static_assert(ConvAlgorithmDescriptor<ConvAlgorithm>);
|
||||
static_assert(SpecifiesThreadBlock<ConvAlgorithm>);
|
||||
static_assert(SpecifiesConvTuning<ConvAlgorithm>);
|
||||
static_assert(SpecifiesGemmPipelineVersion<ConvAlgorithm>);
|
||||
static_assert(SpecifiesBlockATransfer<ConvAlgorithm>);
|
||||
static_assert(SpecifiesBlockBTransfer<ConvAlgorithm>);
|
||||
static_assert(SpecifiesBlockCTransfer<ConvAlgorithm>);
|
||||
static_assert(SpecifiesBlockAVectorTransfer<ConvAlgorithm>);
|
||||
static_assert(SpecifiesBlockBVectorTransfer<ConvAlgorithm>);
|
||||
static_assert(SpecifiesBlockCVectorTransfer<ConvAlgorithm>);
|
||||
static_assert(SpecifiesAThreadClusterAccessOrder<ConvAlgorithm>);
|
||||
static_assert(SpecifiesBThreadClusterAccessOrder<ConvAlgorithm>);
|
||||
static_assert(SpecifiesASourceAccessOrder<ConvAlgorithm>);
|
||||
static_assert(SpecifiesBSourceAccessOrder<ConvAlgorithm>);
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
@@ -3,9 +3,7 @@
|
||||
#include <concepts>
|
||||
#include <type_traits>
|
||||
|
||||
#include <ck_tile/builder/conv_algorithm.hpp>
|
||||
#include <ck_tile/builder/conv_factory.hpp>
|
||||
#include <ck_tile/builder/conv_signature.hpp>
|
||||
#include <ck_tile/builder/versions.hpp>
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
@@ -33,12 +33,12 @@
|
||||
#include <ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/impl/device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp>
|
||||
#include <ck_tile/builder/conv_signature.hpp>
|
||||
#include <ck_tile/builder/conv_algorithm.hpp>
|
||||
#include <ck_tile/builder/conv_algorithm_types.hpp>
|
||||
#include <ck_tile/builder/builder_utils.hpp>
|
||||
#include <ck_tile/builder/types.hpp>
|
||||
#include <ck_tile/builder/versions.hpp>
|
||||
|
||||
namespace ck_tile::builder {
|
||||
namespace ck_tile::builder::factory_internal {
|
||||
|
||||
// Type mappings from the builder GroupConvLayout enum class to the CK tensor data types.
|
||||
template <GroupConvLayout Layout, int SPATIAL_DIM, ConvDirection DIR>
|
||||
@@ -374,6 +374,10 @@ constexpr ck::BlockGemmPipelineVersion SetBlockGemmPipelineVersion()
|
||||
return ck::BlockGemmPipelineVersion::v4;
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::factory
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// Primary template for the convolution factory.
|
||||
template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
ConvAlgorithmDescriptor auto ALGORITHM,
|
||||
@@ -388,20 +392,20 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
{
|
||||
static constexpr int SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = ConvTensorLayouts<SIGNATURE.layout, SPATIAL_DIM, ConvDirection::FORWARD>;
|
||||
using Types = ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = ConvPassThroughOps;
|
||||
static constexpr ConvSpec SPECIALIZATION{
|
||||
using Layouts = factory_internal::ConvTensorLayouts<SIGNATURE.layout, SPATIAL_DIM, ConvDirection::FORWARD>;
|
||||
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = factory_internal::ConvPassThroughOps;
|
||||
static constexpr factory_internal::ConvSpec SPECIALIZATION{
|
||||
.conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
|
||||
.gemm_spec = ck::tensor_operation::device::GemmSpecialization::MNKPadding,
|
||||
};
|
||||
static constexpr ConvBlock BLOCK = SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr ConvTuning TUNING = SetConvTuningInfo<SIGNATURE, ALGORITHM>();
|
||||
static constexpr BlockTransfer A_BLOCK_TRANSFER = SetFwdConvABlockTransfer<ALGORITHM>();
|
||||
static constexpr BlockTransfer B_BLOCK_TRANSFER = SetFwdConvBBlockTransfer<ALGORITHM>();
|
||||
static constexpr CBlockTransfer C_BLOCK_TRANSFER = SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto PIPELINE_VERSION = SetBlockGemmPipelineVersion<ALGORITHM>();
|
||||
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto TUNING = factory_internal::SetConvTuningInfo<SIGNATURE, ALGORITHM>();
|
||||
static constexpr auto A_BLOCK_TRANSFER = factory_internal::SetFwdConvABlockTransfer<ALGORITHM>();
|
||||
static constexpr auto B_BLOCK_TRANSFER = factory_internal::SetFwdConvBBlockTransfer<ALGORITHM>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = factory_internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto PIPELINE_VERSION = factory_internal::SetBlockGemmPipelineVersion<ALGORITHM>();
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< //
|
||||
@@ -461,20 +465,20 @@ template <ConvSignatureDescriptor auto SIGNATURE,
|
||||
struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
{
|
||||
static constexpr int SPATIAL_DIM = SIGNATURE.spatial_dim;
|
||||
using Layouts = ConvTensorLayouts<SIGNATURE.layout, SPATIAL_DIM, ConvDirection::BACKWARD_DATA>;
|
||||
using Types = ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = ConvPassThroughOps;
|
||||
static constexpr ConvSpec SPECIALIZATION{
|
||||
using Layouts = factory_internal::ConvTensorLayouts<SIGNATURE.layout, SPATIAL_DIM, ConvDirection::BACKWARD_DATA>;
|
||||
using Types = factory_internal::ConvTensorTypes<SIGNATURE.data_type>;
|
||||
using Ops = factory_internal::ConvPassThroughOps;
|
||||
static constexpr factory_internal::ConvSpec SPECIALIZATION{
|
||||
.conv_spec = ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default,
|
||||
.gemm_spec = ck::tensor_operation::device::GemmSpecialization::MNKPadding,
|
||||
};
|
||||
static constexpr ConvBlock BLOCK = SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr ConvTuning TUNING = SetConvTuningInfo<SIGNATURE, ALGORITHM>();
|
||||
static constexpr BlockTransfer A_BLOCK_TRANSFER = SetBwdDataConvABlockTransfer<ALGORITHM>();
|
||||
static constexpr BlockTransfer B_BLOCK_TRANSFER = SetBwdDataConvBBlockTransfer<ALGORITHM>();
|
||||
static constexpr CBlockTransfer C_BLOCK_TRANSFER = SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto PIPELINE_VERSION = SetBlockGemmPipelineVersion<ALGORITHM>();
|
||||
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto TUNING = factory_internal::SetConvTuningInfo<SIGNATURE, ALGORITHM>();
|
||||
static constexpr auto A_BLOCK_TRANSFER = factory_internal::SetBwdDataConvABlockTransfer<ALGORITHM>();
|
||||
static constexpr auto B_BLOCK_TRANSFER = factory_internal::SetBwdDataConvBBlockTransfer<ALGORITHM>();
|
||||
static constexpr auto C_BLOCK_TRANSFER = factory_internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto PIPELINE_VERSION = factory_internal::SetBlockGemmPipelineVersion<ALGORITHM>();
|
||||
// The backward-data convolution kernel class instance.
|
||||
using Instance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1<
|
||||
|
||||
@@ -21,13 +21,6 @@
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
// Memory layouts for convolution tensors, following PyTorch conventions.
|
||||
enum class GroupConvLayout
|
||||
{
|
||||
CHANNELS_LAST, // e.g., NHWGC
|
||||
CHANNELS_FIRST // e.g., NGCHW
|
||||
};
|
||||
|
||||
// Constrains convolution to 1D, 2D, or 3D spatial dimensions.
|
||||
template <auto N>
|
||||
concept ConvSpatialDim = std::is_integral_v<decltype(N)> && (N == 1 || N == 2 || N == 3);
|
||||
@@ -36,25 +29,6 @@ concept ConvSpatialDim = std::is_integral_v<decltype(N)> && (N == 1 || N == 2 ||
|
||||
template <DataType T>
|
||||
concept ConvDataType = (T == DataType::FP32) || (T == DataType::FP16) || (T == DataType::BF16);
|
||||
|
||||
// Direction of the convolution operation.
|
||||
enum class ConvDirection
|
||||
{
|
||||
FORWARD,
|
||||
BACKWARD_DATA,
|
||||
BACKWARD_WEIGHT
|
||||
};
|
||||
|
||||
// Fused element-wise operations.
|
||||
enum class ElementwiseOperation
|
||||
{
|
||||
BIAS,
|
||||
BIAS_CLAMP,
|
||||
BILINEAR,
|
||||
CLAMP,
|
||||
SCALE,
|
||||
PASS_THROUGH
|
||||
};
|
||||
|
||||
// Concept for a type that defines a convolution's operational signature.
|
||||
template <typename T>
|
||||
concept ConvSignatureDescriptor = requires(T t) {
|
||||
@@ -83,41 +57,4 @@ concept ConvDirectionIsBackwardData = (Sig.direction == ConvDirection::BACKWARD_
|
||||
template <auto Sig>
|
||||
concept ConvDirectionIsBackwardWeight = (Sig.direction == ConvDirection::BACKWARD_WEIGHT);
|
||||
|
||||
// Helper functions to convert enums to strings
|
||||
constexpr std::string_view ConvDirectionToString(ConvDirection dir)
|
||||
{
|
||||
switch(dir)
|
||||
{
|
||||
case ConvDirection::FORWARD: return "Forward";
|
||||
case ConvDirection::BACKWARD_DATA: return "Backward Data";
|
||||
case ConvDirection::BACKWARD_WEIGHT: return "Backward Weight";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
constexpr std::string_view DataTypeToString(DataType dt)
|
||||
{
|
||||
switch(dt)
|
||||
{
|
||||
case DataType::FP16: return "FP16";
|
||||
case DataType::FP32: return "FP32";
|
||||
case DataType::FP64: return "FP64";
|
||||
case DataType::BF16: return "BF16";
|
||||
case DataType::S16: return "S16";
|
||||
case DataType::S8: return "S8";
|
||||
case DataType::S4: return "S4";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
constexpr std::string_view LayoutToString(GroupConvLayout layout)
|
||||
{
|
||||
switch(layout)
|
||||
{
|
||||
case GroupConvLayout::CHANNELS_FIRST: return "Channels-first (NCHW)";
|
||||
case GroupConvLayout::CHANNELS_LAST: return "Channels-last (NHWC)";
|
||||
default: return "Unknown";
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -13,4 +13,39 @@ enum class DataType
|
||||
S4,
|
||||
};
|
||||
|
||||
// Memory layouts for convolution tensors, following PyTorch conventions.
|
||||
enum class GroupConvLayout
|
||||
{
|
||||
CHANNELS_LAST, // e.g., NHWGC
|
||||
CHANNELS_FIRST // e.g., NGCHW
|
||||
};
|
||||
|
||||
// Direction of the convolution operation.
|
||||
enum class ConvDirection
|
||||
{
|
||||
FORWARD,
|
||||
BACKWARD_DATA,
|
||||
BACKWARD_WEIGHT
|
||||
};
|
||||
|
||||
// Fused element-wise operations.
|
||||
enum class ElementwiseOperation
|
||||
{
|
||||
BIAS,
|
||||
BIAS_CLAMP,
|
||||
BILINEAR,
|
||||
CLAMP,
|
||||
SCALE,
|
||||
PASS_THROUGH
|
||||
};
|
||||
|
||||
// Enums for the current block GEMM pipeline versions.
|
||||
enum class BlockGemmPipelineVersion
|
||||
{
|
||||
V1,
|
||||
V3,
|
||||
V4,
|
||||
V5
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
Reference in New Issue
Block a user