diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm.hpp index b39ead09f3..d45186ece6 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm.hpp @@ -1,94 +1,47 @@ #pragma once #include +#include namespace ck_tile::builder { -enum class GemmImplementationType -{ - XDL, - WMMA, - DL +// Convenience struct for a tuple of m, n, and k values. +template +struct MNK { + T m{}; + T n{}; + T k{}; }; -enum class ConvolutionDirection -{ - Forward, - BackwardData, - BackwardWeight + +// Concept for thread block info for a GEMM problem. +template +concept ThreadBlockInfo = requires(T t) { + { t.block_size } -> std::convertible_to; + { t.sub_matrix.m } -> std::convertible_to; + { t.sub_matrix.n } -> std::convertible_to; + { t.sub_matrix.k } -> std::convertible_to; }; -enum class UniversalGemmSupport -{ - Supported, - NotSupported + +// Describe a thread block for a GEMM. +struct ThreadBlock { + // Thread block size. + int block_size; + // Size of the submatrix problem in a thread block. + MNK sub_matrix; }; -enum class SplitKSupport -{ - Supported, - SupportedTwoStage, - NotSupported +static_assert(ThreadBlockInfo); + +// Concept to check if struct provides thread block info. 
+template +concept HasThreadBlockInfo = requires { + { T::THREAD_BLOCK } -> ThreadBlockInfo; }; -enum class DepthwiseOptimization -{ - X16, - X8, - X4, - X2, - NotSupported -}; - -enum class LargeTensorSupport -{ - Supported, - SplitBatch, - NotSupported -}; - -enum class ImplementationType -{ - ExplicitDefault, - ExplicitMPadding, - ExplicitNPadding, - ExplicitKPadding, - ExplicitMNPadding, - ExplicitMKPadding, - ExplicitNKPadding, - ExplicitMNKPadding, - Implicit -}; - -enum class GemmPipelineVersion -{ - Naive, - ComputeFriendly, - MemFriendly, - ComputeFriendlyDoubleLDS, - ComputeFriendlyDoubleGlobalPrefetch -}; - -enum class GemmPipelineScheduler -{ - Intrawave, - Interwave -}; - -enum class ConvolutionSpecialization -{ - Default, - Filter1x1Pad0, - Filter1x1Stride1Pad0, - Filter3x3 -}; - -enum class MFMAInstructionSize -{ - M16N16, - M32N32 -}; +// No requirements yet for a ConvAlgorithm concept. template concept ConvAlgorithm = std::is_class_v; diff --git a/experimental/builder/include/ck_tile/builder/conv_factory.hpp b/experimental/builder/include/ck_tile/builder/conv_factory.hpp index ebdf764b38..72011fce9f 100644 --- a/experimental/builder/include/ck_tile/builder/conv_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_factory.hpp @@ -1,6 +1,7 @@ #pragma once -// #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" +// #include +// "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp" #include #include #include @@ -65,15 +66,6 @@ struct ConvSpec ck::tensor_operation::device::GemmSpecialization::MNKPadding; }; -// Store M,N,K values. -template -struct MNK -{ - T m = 0; - T n = 0; - T k = 0; -}; - // Block info for a convlution. 
struct ConvBlock { @@ -81,6 +73,23 @@ struct ConvBlock MNK per_block; }; +template +constexpr ConvBlock SetThreadBlockInfo() +{ + if constexpr(HasThreadBlockInfo) + { + constexpr auto& TB = Algo::THREAD_BLOCK; + return ConvBlock{ + .block_size = TB.block_size, + .per_block = {.m = TB.sub_matrix.m, .n = TB.sub_matrix.n, .k = TB.sub_matrix.k}}; + } + // Default values if thread block info isn't specified. + return ConvBlock{ + .block_size = 256, + .per_block = {.m = 256, .n = 256, .k = 32}, + }; +} + // Convolution tuning parameters. struct ConvTuning { @@ -119,17 +128,14 @@ template struct GroupedConvForwardXldCShuffleFactoryV3 { static constexpr int SPATIAL_DIM = Signature::SPATIAL_DIM; - using Layouts = ConvTensorLayouts; - using Types = ConvTensorTypes; - using Ops = ConvPassThroughOps; + using Layouts = ConvTensorLayouts; + using Types = ConvTensorTypes; + using Ops = ConvPassThroughOps; static constexpr ConvSpec SPECIALIZATION{ .conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default, .gemm_spec = ck::tensor_operation::device::GemmSpecialization::MNKPadding, }; - static constexpr ConvBlock BLOCK{ - .block_size = 256, - .per_block = {.m = 256, .n = 256, .k = 32}, - }; + static constexpr ConvBlock BLOCK = SetThreadBlockInfo(); static constexpr ConvTuning TUNING{ .ak1 = 8, .ak2 = 8, @@ -165,53 +171,54 @@ struct GroupedConvForwardXldCShuffleFactoryV3 static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave; static constexpr auto PIPELINE_VERSION = ck::BlockGemmPipelineVersion::v4; // The convlution kernel class instance. 
- using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< // - SPATIAL_DIM, - typename Layouts::ALayout, - typename Layouts::BLayout, - typename Layouts::DsLayout, - typename Layouts::ELayout, - typename Types::ADataType, - typename Types::BDataType, - typename Types::AccDataType, - typename Types::CShuffleDataType, - typename Types::DsDataTypes, - typename Types::EDataTYpe, - typename Ops::AElementwiseOp, - typename Ops::BElementwiseOp, - typename Ops::CDEElementwiseOp, - SPECIALIZATION.conv_spec, - SPECIALIZATION.gemm_spec, - BLOCK.block_size, - BLOCK.per_block.m, - BLOCK.per_block.n, - BLOCK.per_block.k, - TUNING.ak1, - TUNING.ak2, - TUNING.m_per_xdl, - TUNING.n_per_dxl, - TUNING.m_xdl_per_wave, - TUNING.n_xdl_per_wave, - ToSequence, - ToSequence, - ToSequence, - A_BLOCK_TRANSFER.src_vector_dim, - A_BLOCK_TRANSFER.src_scaler_per_vector, - A_BLOCK_TRANSFER.dest_scaler_per_vector_k1, - A_BLOCK_TRANSFER.add_extra, - ToSequence, - ToSequence, - ToSequence, - B_BLOCK_TRANSFER.src_vector_dim, - B_BLOCK_TRANSFER.src_scaler_per_vector, - B_BLOCK_TRANSFER.dest_scaler_per_vector_k1, - B_BLOCK_TRANSFER.add_extra, - C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, - C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, - ToSequence, - C_BLOCK_TRANSFER.scaler_per_vector, - PIPELINE_SCHEDULER, - PIPELINE_VERSION>; + using Instance = + ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< // + SPATIAL_DIM, + typename Layouts::ALayout, + typename Layouts::BLayout, + typename Layouts::DsLayout, + typename Layouts::ELayout, + typename Types::ADataType, + typename Types::BDataType, + typename Types::AccDataType, + typename Types::CShuffleDataType, + typename Types::DsDataTypes, + typename Types::EDataTYpe, + typename Ops::AElementwiseOp, + typename Ops::BElementwiseOp, + typename Ops::CDEElementwiseOp, + SPECIALIZATION.conv_spec, + SPECIALIZATION.gemm_spec, + BLOCK.block_size, + BLOCK.per_block.m, + BLOCK.per_block.n, + 
BLOCK.per_block.k, + TUNING.ak1, + TUNING.ak2, + TUNING.m_per_xdl, + TUNING.n_per_dxl, + TUNING.m_xdl_per_wave, + TUNING.n_xdl_per_wave, + ToSequence, + ToSequence, + ToSequence, + A_BLOCK_TRANSFER.src_vector_dim, + A_BLOCK_TRANSFER.src_scaler_per_vector, + A_BLOCK_TRANSFER.dest_scaler_per_vector_k1, + A_BLOCK_TRANSFER.add_extra, + ToSequence, + ToSequence, + ToSequence, + B_BLOCK_TRANSFER.src_vector_dim, + B_BLOCK_TRANSFER.src_scaler_per_vector, + B_BLOCK_TRANSFER.dest_scaler_per_vector_k1, + B_BLOCK_TRANSFER.add_extra, + C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle, + C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle, + ToSequence, + C_BLOCK_TRANSFER.scaler_per_vector, + PIPELINE_SCHEDULER, + PIPELINE_VERSION>; }; } // namespace ck_tile::builder diff --git a/experimental/builder/test/test_conv_builder.cpp b/experimental/builder/test/test_conv_builder.cpp index d761b277ba..831e49a9d1 100644 --- a/experimental/builder/test/test_conv_builder.cpp +++ b/experimental/builder/test/test_conv_builder.cpp @@ -41,4 +41,23 @@ TEST(ConvBuilderTest, TestInstance) "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, " "8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>"); } + +struct ConvFwdXdlBf16CompInstances2xAlgorithm0 +{ + static constexpr ckb::ThreadBlock THREAD_BLOCK{ + .block_size = 256, + .sub_matrix = {.m = 256, .n = 256, .k = 32}, + }; +}; + +TEST(ConvBuilderTest, TestInstance0) +{ + using Builder = + ckb::ConvBuilder; + EXPECT_EQ( + Builder::Instance::TypeString(), + "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, " + "8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>"); +} + } // namespace