Add thread block info to factory.

Now we can set the thread block size and submatrix shape for the builder.
This commit is contained in:
John Shumway
2025-09-01 21:54:41 +00:00
parent cee90b800e
commit 061fb06eef
3 changed files with 119 additions and 140 deletions

View File

@@ -1,94 +1,47 @@
#pragma once
#include <type_traits>
#include <concepts>
namespace ck_tile::builder {
enum class GemmImplementationType
{
XDL,
WMMA,
DL
// Convenience struct for a tuple of m, n, and k values.
// All three members have the same element type T and carry `{}` default
// member initializers, so a default-constructed MNK<T> holds
// value-initialized members (zero for arithmetic T).
// NOTE(review): m/n/k presumably name GEMM dimensions — confirm with users.
template <typename T>
struct MNK {
T m{};
T n{};
T k{};
};
enum class ConvolutionDirection
{
Forward,
BackwardData,
BackwardWeight
// Concept for thread block info for a GEMM problem.
// Satisfied by any T whose objects expose:
//   - `block_size`, convertible to int, and
//   - `sub_matrix` with members `m`, `n`, and `k`, each convertible to int.
// Only member access is checked; no constraints are placed on the values.
template <typename T>
concept ThreadBlockInfo = requires(T t) {
{ t.block_size } -> std::convertible_to<int>;
{ t.sub_matrix.m } -> std::convertible_to<int>;
{ t.sub_matrix.n } -> std::convertible_to<int>;
{ t.sub_matrix.k } -> std::convertible_to<int>;
};
enum class UniversalGemmSupport
{
Supported,
NotSupported
// Describe a thread block for a GEMM.
// Plain aggregate; intended to be initialized with designated initializers,
// e.g. ThreadBlock{.block_size = 256, .sub_matrix = {.m = 256, .n = 256, .k = 32}}.
struct ThreadBlock {
// Thread block size.
// Has no default member initializer, so set it explicitly (or
// value-initialize the struct) before reading it.
int block_size;
// Size of the submatrix problem in a thread block.
// Members default to zero via MNK's own initializers.
MNK<int> sub_matrix;
};
enum class SplitKSupport
{
Supported,
SupportedTwoStage,
NotSupported
static_assert(ThreadBlockInfo<ThreadBlock>);
// Concept to check if struct provides thread block info.
// Satisfied when the expression `T::THREAD_BLOCK` is well-formed (e.g. a
// static data member) and its type, as checked by the compound requirement,
// satisfies ThreadBlockInfo.
template <typename T>
concept HasThreadBlockInfo = requires {
{ T::THREAD_BLOCK } -> ThreadBlockInfo;
};
// Enumerators for the depthwise optimization setting: X16, X8, X4, X2, or
// NotSupported.
// NOTE(review): what the X<n> factor measures is not shown here — confirm
// against the code that consumes this enum.
enum class DepthwiseOptimization
{
X16,
X8,
X4,
X2,
NotSupported
};
// Enumerators for large-tensor support: Supported, SplitBatch, or
// NotSupported.
// NOTE(review): presumably SplitBatch means support is obtained by splitting
// along the batch dimension — confirm with the consumer of this enum.
enum class LargeTensorSupport
{
Supported,
SplitBatch,
NotSupported
};
// Enumerators for the implementation type: eight Explicit* variants plus
// Implicit.
// NOTE(review): the Explicit*Padding names presumably identify which of the
// M/N/K dimensions are padded — confirm with the consumer of this enum.
enum class ImplementationType
{
ExplicitDefault,
ExplicitMPadding,
ExplicitNPadding,
ExplicitKPadding,
ExplicitMNPadding,
ExplicitMKPadding,
ExplicitNKPadding,
ExplicitMNKPadding,
Implicit
};
// Enumerators naming the available GEMM pipeline versions.
// NOTE(review): how these map onto the underlying kernel pipeline versions
// is not shown here — confirm before relying on a specific correspondence.
enum class GemmPipelineVersion
{
Naive,
ComputeFriendly,
MemFriendly,
ComputeFriendlyDoubleLDS,
ComputeFriendlyDoubleGlobalPrefetch
};
// Enumerators naming the GEMM pipeline scheduler: Intrawave or Interwave.
// NOTE(review): scheduling semantics are defined by the consumer of this
// enum and are not visible here.
enum class GemmPipelineScheduler
{
Intrawave,
Interwave
};
// Enumerators naming convolution specializations: Default plus three
// filter-shape-specific variants.
// NOTE(review): the names suggest filter size / stride / padding
// restrictions — confirm the exact conditions with the consumer.
enum class ConvolutionSpecialization
{
Default,
Filter1x1Pad0,
Filter1x1Stride1Pad0,
Filter3x3
};
// Enumerators naming the MFMA instruction size: M16N16 or M32N32.
// NOTE(review): presumably the M x N tile handled per instruction —
// confirm with the consumer of this enum.
enum class MFMAInstructionSize
{
M16N16,
M32N32
};
// No requirements yet for a ConvAlgorithm concept.
// Currently any non-union class type satisfies it (std::is_class_v); optional
// features such as thread block info are detected separately.
template <typename T>
concept ConvAlgorithm = std::is_class_v<T>;

View File

@@ -1,6 +1,7 @@
#pragma once
// #include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
// #include
// "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_xdl_comp_instance.hpp"
#include <ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp>
#include <ck_tile/builder/conv_signature.hpp>
#include <ck_tile/builder/conv_algorithm.hpp>
@@ -65,15 +66,6 @@ struct ConvSpec
ck::tensor_operation::device::GemmSpecialization::MNKPadding;
};
// Store M,N,K values.
// Each member is default-initialized from the literal 0, so T must be
// initializable from 0.
template <typename T>
struct MNK
{
T m = 0;
T n = 0;
T k = 0;
};
// Block info for a convolution.
struct ConvBlock
{
@@ -81,6 +73,23 @@ struct ConvBlock
MNK<int> per_block;
};
// Build the ConvBlock for an algorithm descriptor.
//
// When Algo satisfies HasThreadBlockInfo, the block size and the per-block
// m/n/k sizes are copied from Algo::THREAD_BLOCK. Otherwise the fixed
// defaults (block_size 256, per_block 256 x 256 x 32) are returned.
template <ConvAlgorithm Algo>
constexpr ConvBlock SetThreadBlockInfo()
{
    if constexpr(!HasThreadBlockInfo<Algo>)
    {
        // Default values if thread block info isn't specified.
        return ConvBlock{
            .block_size = 256,
            .per_block  = {.m = 256, .n = 256, .k = 32},
        };
    }
    else
    {
        // Algorithm-provided thread block description.
        constexpr const auto& thread_block = Algo::THREAD_BLOCK;
        constexpr const auto& tile         = thread_block.sub_matrix;
        return ConvBlock{
            .block_size = thread_block.block_size,
            .per_block  = {.m = tile.m, .n = tile.n, .k = tile.k},
        };
    }
}
// Convolution tuning parameters.
struct ConvTuning
{
@@ -119,17 +128,14 @@ template <ConvSignature Signature, ConvAlgorithm Algorithm, auto Version>
struct GroupedConvForwardXldCShuffleFactoryV3
{
static constexpr int SPATIAL_DIM = Signature::SPATIAL_DIM;
using Layouts = ConvTensorLayouts<Signature::LAYOUT>;
using Types = ConvTensorTypes<Signature::DATA_TYPE>;
using Ops = ConvPassThroughOps;
using Layouts = ConvTensorLayouts<Signature::LAYOUT>;
using Types = ConvTensorTypes<Signature::DATA_TYPE>;
using Ops = ConvPassThroughOps;
static constexpr ConvSpec SPECIALIZATION{
.conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
.gemm_spec = ck::tensor_operation::device::GemmSpecialization::MNKPadding,
};
static constexpr ConvBlock BLOCK{
.block_size = 256,
.per_block = {.m = 256, .n = 256, .k = 32},
};
static constexpr ConvBlock BLOCK = SetThreadBlockInfo<Algorithm>();
static constexpr ConvTuning TUNING{
.ak1 = 8,
.ak2 = 8,
@@ -165,53 +171,54 @@ struct GroupedConvForwardXldCShuffleFactoryV3
static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave;
static constexpr auto PIPELINE_VERSION = ck::BlockGemmPipelineVersion::v4;
// The convolution kernel class instance.
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< //
SPATIAL_DIM,
typename Layouts::ALayout,
typename Layouts::BLayout,
typename Layouts::DsLayout,
typename Layouts::ELayout,
typename Types::ADataType,
typename Types::BDataType,
typename Types::AccDataType,
typename Types::CShuffleDataType,
typename Types::DsDataTypes,
typename Types::EDataTYpe,
typename Ops::AElementwiseOp,
typename Ops::BElementwiseOp,
typename Ops::CDEElementwiseOp,
SPECIALIZATION.conv_spec,
SPECIALIZATION.gemm_spec,
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
TUNING.ak1,
TUNING.ak2,
TUNING.m_per_xdl,
TUNING.n_per_dxl,
TUNING.m_xdl_per_wave,
TUNING.n_xdl_per_wave,
ToSequence<A_BLOCK_TRANSFER.thread_cluster_lengths>,
ToSequence<A_BLOCK_TRANSFER.thread_cluster_order>,
ToSequence<A_BLOCK_TRANSFER.src_access_order>,
A_BLOCK_TRANSFER.src_vector_dim,
A_BLOCK_TRANSFER.src_scaler_per_vector,
A_BLOCK_TRANSFER.dest_scaler_per_vector_k1,
A_BLOCK_TRANSFER.add_extra,
ToSequence<B_BLOCK_TRANSFER.thread_cluster_lengths>,
ToSequence<B_BLOCK_TRANSFER.thread_cluster_order>,
ToSequence<B_BLOCK_TRANSFER.src_access_order>,
B_BLOCK_TRANSFER.src_vector_dim,
B_BLOCK_TRANSFER.src_scaler_per_vector,
B_BLOCK_TRANSFER.dest_scaler_per_vector_k1,
B_BLOCK_TRANSFER.add_extra,
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
ToSequence<C_BLOCK_TRANSFER.cluster_lengths>,
C_BLOCK_TRANSFER.scaler_per_vector,
PIPELINE_SCHEDULER,
PIPELINE_VERSION>;
using Instance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< //
SPATIAL_DIM,
typename Layouts::ALayout,
typename Layouts::BLayout,
typename Layouts::DsLayout,
typename Layouts::ELayout,
typename Types::ADataType,
typename Types::BDataType,
typename Types::AccDataType,
typename Types::CShuffleDataType,
typename Types::DsDataTypes,
typename Types::EDataTYpe,
typename Ops::AElementwiseOp,
typename Ops::BElementwiseOp,
typename Ops::CDEElementwiseOp,
SPECIALIZATION.conv_spec,
SPECIALIZATION.gemm_spec,
BLOCK.block_size,
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
TUNING.ak1,
TUNING.ak2,
TUNING.m_per_xdl,
TUNING.n_per_dxl,
TUNING.m_xdl_per_wave,
TUNING.n_xdl_per_wave,
ToSequence<A_BLOCK_TRANSFER.thread_cluster_lengths>,
ToSequence<A_BLOCK_TRANSFER.thread_cluster_order>,
ToSequence<A_BLOCK_TRANSFER.src_access_order>,
A_BLOCK_TRANSFER.src_vector_dim,
A_BLOCK_TRANSFER.src_scaler_per_vector,
A_BLOCK_TRANSFER.dest_scaler_per_vector_k1,
A_BLOCK_TRANSFER.add_extra,
ToSequence<B_BLOCK_TRANSFER.thread_cluster_lengths>,
ToSequence<B_BLOCK_TRANSFER.thread_cluster_order>,
ToSequence<B_BLOCK_TRANSFER.src_access_order>,
B_BLOCK_TRANSFER.src_vector_dim,
B_BLOCK_TRANSFER.src_scaler_per_vector,
B_BLOCK_TRANSFER.dest_scaler_per_vector_k1,
B_BLOCK_TRANSFER.add_extra,
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
ToSequence<C_BLOCK_TRANSFER.cluster_lengths>,
C_BLOCK_TRANSFER.scaler_per_vector,
PIPELINE_SCHEDULER,
PIPELINE_VERSION>;
};
} // namespace ck_tile::builder

View File

@@ -41,4 +41,23 @@ TEST(ConvBuilderTest, TestInstance)
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, "
"8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>");
}
// Algorithm descriptor that supplies explicit thread block info through a
// static THREAD_BLOCK member, so it satisfies ckb::HasThreadBlockInfo.
// NOTE(review): these values (256; 256 x 256 x 32) are identical to the
// defaults SetThreadBlockInfo() falls back to, so a builder using this
// descriptor produces the same block configuration as one without it.
struct ConvFwdXdlBf16CompInstances2xAlgorithm0
{
static constexpr ckb::ThreadBlock THREAD_BLOCK{
.block_size = 256,
.sub_matrix = {.m = 256, .n = 256, .k = 32},
};
};
// Builds the kernel instance from an algorithm descriptor that provides
// THREAD_BLOCK and checks the resulting type string.
// NOTE(review): the descriptor's values match the factory defaults, so the
// expected string is the same as in TestInstance; this test would still pass
// if THREAD_BLOCK were ignored. Consider a non-default configuration.
TEST(ConvBuilderTest, TestInstance0)
{
    using Algorithm = ConvFwdXdlBf16CompInstances2xAlgorithm0;
    using Builder   = ckb::ConvBuilder<FwdConvSignature, Algorithm, API_VERSION>;

    constexpr char kExpectedTypeString[] =
        "DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, "
        "8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>";

    EXPECT_EQ(Builder::Instance::TypeString(), kExpectedTypeString);
}
} // namespace