mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Add thread block info to factory.
Now we can set the thread block size and submatrix shape for the builder.
This commit is contained in:
@@ -1,94 +1,47 @@
|
||||
#pragma once
|
||||
|
||||
#include <type_traits>
|
||||
#include <concepts>
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
enum class GemmImplementationType
|
||||
{
|
||||
XDL,
|
||||
WMMA,
|
||||
DL
|
||||
// Convenience struct for a tuple of m, n, and k values.
|
||||
template <typename T>
|
||||
struct MNK {
|
||||
T m{};
|
||||
T n{};
|
||||
T k{};
|
||||
};
|
||||
|
||||
enum class ConvolutionDirection
|
||||
{
|
||||
Forward,
|
||||
BackwardData,
|
||||
BackwardWeight
|
||||
|
||||
// Concept for thread block info for a GEMM problem.
|
||||
template <typename T>
|
||||
concept ThreadBlockInfo = requires(T t) {
|
||||
{ t.block_size } -> std::convertible_to<int>;
|
||||
{ t.sub_matrix.m } -> std::convertible_to<int>;
|
||||
{ t.sub_matrix.n } -> std::convertible_to<int>;
|
||||
{ t.sub_matrix.k } -> std::convertible_to<int>;
|
||||
};
|
||||
|
||||
enum class UniversalGemmSupport
|
||||
{
|
||||
Supported,
|
||||
NotSupported
|
||||
|
||||
// Describe a thread block for a GEMM.
|
||||
struct ThreadBlock {
|
||||
// Thread block size.
|
||||
int block_size;
|
||||
// Size of the submatrix problem in a thread block.
|
||||
MNK<int> sub_matrix;
|
||||
};
|
||||
|
||||
enum class SplitKSupport
|
||||
{
|
||||
Supported,
|
||||
SupportedTwoStage,
|
||||
NotSupported
|
||||
static_assert(ThreadBlockInfo<ThreadBlock>);
|
||||
|
||||
// Concept to check if struct provides thread block info.
|
||||
template <typename T>
|
||||
concept HasThreadBlockInfo = requires {
|
||||
{ T::THREAD_BLOCK } -> ThreadBlockInfo;
|
||||
};
|
||||
|
||||
enum class DepthwiseOptimization
|
||||
{
|
||||
X16,
|
||||
X8,
|
||||
X4,
|
||||
X2,
|
||||
NotSupported
|
||||
};
|
||||
|
||||
enum class LargeTensorSupport
|
||||
{
|
||||
Supported,
|
||||
SplitBatch,
|
||||
NotSupported
|
||||
};
|
||||
|
||||
enum class ImplementationType
|
||||
{
|
||||
ExplicitDefault,
|
||||
ExplicitMPadding,
|
||||
ExplicitNPadding,
|
||||
ExplicitKPadding,
|
||||
ExplicitMNPadding,
|
||||
ExplicitMKPadding,
|
||||
ExplicitNKPadding,
|
||||
ExplicitMNKPadding,
|
||||
Implicit
|
||||
};
|
||||
|
||||
enum class GemmPipelineVersion
|
||||
{
|
||||
Naive,
|
||||
ComputeFriendly,
|
||||
MemFriendly,
|
||||
ComputeFriendlyDoubleLDS,
|
||||
ComputeFriendlyDoubleGlobalPrefetch
|
||||
};
|
||||
|
||||
enum class GemmPipelineScheduler
|
||||
{
|
||||
Intrawave,
|
||||
Interwave
|
||||
};
|
||||
|
||||
enum class ConvolutionSpecialization
|
||||
{
|
||||
Default,
|
||||
Filter1x1Pad0,
|
||||
Filter1x1Stride1Pad0,
|
||||
Filter3x3
|
||||
};
|
||||
|
||||
enum class MFMAInstructionSize
|
||||
{
|
||||
M16N16,
|
||||
M32N32
|
||||
};
|
||||
|
||||
// No requirements yet for a ConvAlogorithm concept.
|
||||
template <typename T>
|
||||
concept ConvAlgorithm = std::is_class_v<T>;
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
// #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
// #include
|
||||
// "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
|
||||
#include <ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
|
||||
#include <ck_tile/builder/conv_signature.hpp>
|
||||
#include <ck_tile/builder/conv_algorithm.hpp>
|
||||
@@ -65,15 +66,6 @@ struct ConvSpec
|
||||
ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
};
|
||||
|
||||
// Store M,N,K values.
|
||||
template <typename T>
|
||||
struct MNK
|
||||
{
|
||||
T m = 0;
|
||||
T n = 0;
|
||||
T k = 0;
|
||||
};
|
||||
|
||||
// Block info for a convlution.
|
||||
struct ConvBlock
|
||||
{
|
||||
@@ -81,6 +73,23 @@ struct ConvBlock
|
||||
MNK<int> per_block;
|
||||
};
|
||||
|
||||
template <ConvAlgorithm Algo>
|
||||
constexpr ConvBlock SetThreadBlockInfo()
|
||||
{
|
||||
if constexpr(HasThreadBlockInfo<Algo>)
|
||||
{
|
||||
constexpr auto& TB = Algo::THREAD_BLOCK;
|
||||
return ConvBlock{
|
||||
.block_size = TB.block_size,
|
||||
.per_block = {.m = TB.sub_matrix.m, .n = TB.sub_matrix.n, .k = TB.sub_matrix.k}};
|
||||
}
|
||||
// Default values if thread block info isn't specified.
|
||||
return ConvBlock{
|
||||
.block_size = 256,
|
||||
.per_block = {.m = 256, .n = 256, .k = 32},
|
||||
};
|
||||
}
|
||||
|
||||
// Convolution tuning parameters.
|
||||
struct ConvTuning
|
||||
{
|
||||
@@ -119,17 +128,14 @@ template <ConvSignature Signature, ConvAlgorithm Algorithm, auto Version>
|
||||
struct GroupedConvForwardXldCShuffleFactoryV3
|
||||
{
|
||||
static constexpr int SPATIAL_DIM = Signature::SPATIAL_DIM;
|
||||
using Layouts = ConvTensorLayouts<Signature::LAYOUT>;
|
||||
using Types = ConvTensorTypes<Signature::DATA_TYPE>;
|
||||
using Ops = ConvPassThroughOps;
|
||||
using Layouts = ConvTensorLayouts<Signature::LAYOUT>;
|
||||
using Types = ConvTensorTypes<Signature::DATA_TYPE>;
|
||||
using Ops = ConvPassThroughOps;
|
||||
static constexpr ConvSpec SPECIALIZATION{
|
||||
.conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
|
||||
.gemm_spec = ck::tensor_operation::device::GemmSpecialization::MNKPadding,
|
||||
};
|
||||
static constexpr ConvBlock BLOCK{
|
||||
.block_size = 256,
|
||||
.per_block = {.m = 256, .n = 256, .k = 32},
|
||||
};
|
||||
static constexpr ConvBlock BLOCK = SetThreadBlockInfo<Algorithm>();
|
||||
static constexpr ConvTuning TUNING{
|
||||
.ak1 = 8,
|
||||
.ak2 = 8,
|
||||
@@ -165,53 +171,54 @@ struct GroupedConvForwardXldCShuffleFactoryV3
|
||||
static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto PIPELINE_VERSION = ck::BlockGemmPipelineVersion::v4;
|
||||
// The convlution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< //
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::CShuffleDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataTYpe,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
SPECIALIZATION.conv_spec,
|
||||
SPECIALIZATION.gemm_spec,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
TUNING.ak1,
|
||||
TUNING.ak2,
|
||||
TUNING.m_per_xdl,
|
||||
TUNING.n_per_dxl,
|
||||
TUNING.m_xdl_per_wave,
|
||||
TUNING.n_xdl_per_wave,
|
||||
ToSequence<A_BLOCK_TRANSFER.thread_cluster_lengths>,
|
||||
ToSequence<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
ToSequence<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scaler_per_vector,
|
||||
A_BLOCK_TRANSFER.dest_scaler_per_vector_k1,
|
||||
A_BLOCK_TRANSFER.add_extra,
|
||||
ToSequence<B_BLOCK_TRANSFER.thread_cluster_lengths>,
|
||||
ToSequence<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
ToSequence<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scaler_per_vector,
|
||||
B_BLOCK_TRANSFER.dest_scaler_per_vector_k1,
|
||||
B_BLOCK_TRANSFER.add_extra,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
ToSequence<C_BLOCK_TRANSFER.cluster_lengths>,
|
||||
C_BLOCK_TRANSFER.scaler_per_vector,
|
||||
PIPELINE_SCHEDULER,
|
||||
PIPELINE_VERSION>;
|
||||
using Instance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< //
|
||||
SPATIAL_DIM,
|
||||
typename Layouts::ALayout,
|
||||
typename Layouts::BLayout,
|
||||
typename Layouts::DsLayout,
|
||||
typename Layouts::ELayout,
|
||||
typename Types::ADataType,
|
||||
typename Types::BDataType,
|
||||
typename Types::AccDataType,
|
||||
typename Types::CShuffleDataType,
|
||||
typename Types::DsDataTypes,
|
||||
typename Types::EDataTYpe,
|
||||
typename Ops::AElementwiseOp,
|
||||
typename Ops::BElementwiseOp,
|
||||
typename Ops::CDEElementwiseOp,
|
||||
SPECIALIZATION.conv_spec,
|
||||
SPECIALIZATION.gemm_spec,
|
||||
BLOCK.block_size,
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
TUNING.ak1,
|
||||
TUNING.ak2,
|
||||
TUNING.m_per_xdl,
|
||||
TUNING.n_per_dxl,
|
||||
TUNING.m_xdl_per_wave,
|
||||
TUNING.n_xdl_per_wave,
|
||||
ToSequence<A_BLOCK_TRANSFER.thread_cluster_lengths>,
|
||||
ToSequence<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
ToSequence<A_BLOCK_TRANSFER.src_access_order>,
|
||||
A_BLOCK_TRANSFER.src_vector_dim,
|
||||
A_BLOCK_TRANSFER.src_scaler_per_vector,
|
||||
A_BLOCK_TRANSFER.dest_scaler_per_vector_k1,
|
||||
A_BLOCK_TRANSFER.add_extra,
|
||||
ToSequence<B_BLOCK_TRANSFER.thread_cluster_lengths>,
|
||||
ToSequence<B_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
ToSequence<B_BLOCK_TRANSFER.src_access_order>,
|
||||
B_BLOCK_TRANSFER.src_vector_dim,
|
||||
B_BLOCK_TRANSFER.src_scaler_per_vector,
|
||||
B_BLOCK_TRANSFER.dest_scaler_per_vector_k1,
|
||||
B_BLOCK_TRANSFER.add_extra,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
ToSequence<C_BLOCK_TRANSFER.cluster_lengths>,
|
||||
C_BLOCK_TRANSFER.scaler_per_vector,
|
||||
PIPELINE_SCHEDULER,
|
||||
PIPELINE_VERSION>;
|
||||
};
|
||||
|
||||
} // namespace ck_tile::builder
|
||||
|
||||
@@ -41,4 +41,23 @@ TEST(ConvBuilderTest, TestInstance)
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, "
|
||||
"8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>");
|
||||
}
|
||||
|
||||
struct ConvFwdXdlBf16CompInstances2xAlgorithm0
|
||||
{
|
||||
static constexpr ckb::ThreadBlock THREAD_BLOCK{
|
||||
.block_size = 256,
|
||||
.sub_matrix = {.m = 256, .n = 256, .k = 32},
|
||||
};
|
||||
};
|
||||
|
||||
TEST(ConvBuilderTest, TestInstance0)
|
||||
{
|
||||
using Builder =
|
||||
ckb::ConvBuilder<FwdConvSignature, ConvFwdXdlBf16CompInstances2xAlgorithm0, API_VERSION>;
|
||||
EXPECT_EQ(
|
||||
Builder::Instance::TypeString(),
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, "
|
||||
"8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
Reference in New Issue
Block a user