Add block transfer parameters to builder.

These are very hard to test in the kernel class, so just test the values in the factory.
This commit is contained in:
John Shumway
2025-09-02 23:08:32 +00:00
parent 97660c64e5
commit 6a513e1a7f
3 changed files with 195 additions and 38 deletions

View File

@@ -2,6 +2,7 @@
#include <type_traits>
#include <concepts>
#include <array>
namespace ck_tile::builder {
@@ -31,7 +32,6 @@ struct ThreadBlock
// Size of the submatrix problem in a thread block.
MNK<int> sub_matrix;
};
static_assert(ThreadBlockInfo<ThreadBlock>);
// Concept to check if struct provides thread block info.
@@ -58,7 +58,6 @@ struct ConvTuningParams
int m_xdl_per_wave = 0;
int n_xdl_per_wave = 0;
};
static_assert(ConvTuningInfo<ConvTuningParams>);
// Concept to check if a struct provides convolution tuning info.
@@ -67,6 +66,77 @@ concept HasConvTuningInfo = requires {
{ T::tuning_params } -> ConvTuningInfo;
};
// Concept for A block transfer thread cluster lengths.
// Satisfied by any type whose members k0, m and k1 can each be accessed on
// an instance and are convertible to int.
template <typename T>
concept BlockATransferLengths = requires(T t) {
{ t.k0 } -> std::convertible_to<int>;
{ t.m } -> std::convertible_to<int>;
{ t.k1 } -> std::convertible_to<int>;
};
// Describe A block transfer thread cluster lengths.
// Plain aggregate holding the three lengths k0, m and k1 as ints; the
// fields have no default member initializers.
struct BlockATransferLengthsInfo
{
int k0;
int m;
int k1;
};
// Compile-time check that the struct models the matching concept.
static_assert(BlockATransferLengths<BlockATransferLengthsInfo>);
// Concept for B block transfer thread cluster lengths.
// Satisfied by any type whose members k0, n and k1 can each be accessed on
// an instance and are convertible to int.
template <typename T>
concept BlockBTransferLengths = requires(T t) {
{ t.k0 } -> std::convertible_to<int>;
{ t.n } -> std::convertible_to<int>;
{ t.k1 } -> std::convertible_to<int>;
};
// Describe B block transfer thread cluster lengths.
// Plain aggregate holding the three lengths k0, n and k1 as ints; the
// fields have no default member initializers.
struct BlockBTransferLengthsInfo
{
int k0;
int n;
int k1;
};
// Compile-time check that the struct models the matching concept.
static_assert(BlockBTransferLengths<BlockBTransferLengthsInfo>);
// Concept for C block transfer thread cluster lengths.
// Satisfied by any type whose members m_block, m_wave_per_xdl, n_block and
// n_wave_per_xdl can each be accessed on an instance and are convertible to int.
template <typename T>
concept BlockCTransferLengths = requires(T t) {
{ t.m_block } -> std::convertible_to<int>;
{ t.m_wave_per_xdl } -> std::convertible_to<int>;
{ t.n_block } -> std::convertible_to<int>;
{ t.n_wave_per_xdl } -> std::convertible_to<int>;
};
// Describe C block transfer thread cluster lengths.
struct BlockCTransferLengthsInfo
{
int m_block;
int m_wave_per_xdl;
int n_block;
int n_wave_per_xdl;
};
static_assert(BlockBTransferLengths<BlockBTransferLengthsInfo>);
// Concept to check if a struct provides A block transfer info.
// Satisfied when T has a `block_transfer` member whose
// `thread_cluster_lengths_a` member models BlockATransferLengths.
// No requirement parameter is introduced because the check only names
// members through T.
template <typename T>
concept HasABlockTransferInfo = requires {
    { T::block_transfer.thread_cluster_lengths_a } -> BlockATransferLengths;
};
// Concept to check if a struct provides B block transfer info.
// Satisfied when T has a `block_transfer` member whose
// `thread_cluster_lengths_b` member models BlockBTransferLengths.
// No requirement parameter is introduced because the check only names
// members through T.
template <typename T>
concept HasBBlockTransferInfo = requires {
    { T::block_transfer.thread_cluster_lengths_b } -> BlockBTransferLengths;
};
// Concept to check if a struct provides C block transfer info.
// Satisfied when T has a `block_transfer` member whose
// `thread_cluster_lengths_c` member models BlockCTransferLengths.
// No requirement parameter is introduced because the check only names
// members through T.
template <typename T>
concept HasCBlockTransferInfo = requires {
    { T::block_transfer.thread_cluster_lengths_c } -> BlockCTransferLengths;
};
// No requirements yet for a ConvAlgorithm concept beyond T being a class type.
template <typename T>
concept ConvAlgorithm = std::is_class_v<T>;

View File

@@ -129,7 +129,7 @@ constexpr ConvTuning SetConvTuningInfo()
};
}
// Block tranfser paramters for A or B tensor.
// Block transfer parameters for A or B tensor.
struct BlockTransfer
{
ck::Array<int, 3> thread_cluster_lengths = {0, 0, 0}; // k0, m, k1
@@ -144,12 +144,79 @@ struct BlockTransfer
// Block transfer parameters for C tensor.
// Every field defaults to zero except scaler_per_vector, which defaults to 8.
struct CBlockTransfer
{
int m_xdl_per_wave_per_shuffle = 0;
int n_xdl_per_wave_per_shuffle = 0;
ck::Array<int, 4> cluster_lengths = {0, 0, 0, 0};
int scaler_per_vector = 8;
int m_xdl_per_wave_per_shuffle = 0;
int n_xdl_per_wave_per_shuffle = 0;
ck::Array<int, 4> thread_cluster_lengths = {0, 0, 0, 0};
int scaler_per_vector = 8;
};
// Build the A tensor block transfer parameters for ALGORITHM.
// The result starts from the built-in defaults below. When the algorithm
// type satisfies HasABlockTransferInfo, thread_cluster_lengths is replaced
// by the algorithm's (k0, m, k1) values; every other field keeps its default.
template <ConvAlgorithm auto ALGORITHM>
constexpr BlockTransfer SetABlockTransfer()
{
    BlockTransfer result{
        .thread_cluster_lengths    = {4, 64, 1},
        .thread_cluster_order      = {1, 0, 2},
        .src_access_order          = {1, 0, 2},
        .src_vector_dim            = 2,
        .src_scaler_per_vector     = 8,
        .dest_scaler_per_vector_k1 = 8,
        .add_extra                 = 0,
    };

    if constexpr(HasABlockTransferInfo<decltype(ALGORITHM)>)
    {
        // Override only the cluster lengths with the algorithm-supplied values.
        constexpr auto& lengths = ALGORITHM.block_transfer.thread_cluster_lengths_a;
        result.thread_cluster_lengths = {lengths.k0, lengths.m, lengths.k1};
    }

    return result;
}
// Build the B tensor block transfer parameters for ALGORITHM.
// The result starts from the built-in defaults below. When the algorithm
// type satisfies HasBBlockTransferInfo, thread_cluster_lengths is replaced
// by the algorithm's (k0, n, k1) values; every other field keeps its default.
template <ConvAlgorithm auto ALGORITHM>
constexpr BlockTransfer SetBBlockTransfer()
{
    BlockTransfer result{
        .thread_cluster_lengths    = {4, 64, 1},
        .thread_cluster_order      = {1, 0, 2},
        .src_access_order          = {1, 0, 2},
        .src_vector_dim            = 2,
        .src_scaler_per_vector     = 8,
        .dest_scaler_per_vector_k1 = 8,
        .add_extra                 = 0,
    };

    if constexpr(HasBBlockTransferInfo<decltype(ALGORITHM)>)
    {
        // Override only the cluster lengths with the algorithm-supplied values.
        constexpr auto& lengths = ALGORITHM.block_transfer.thread_cluster_lengths_b;
        result.thread_cluster_lengths = {lengths.k0, lengths.n, lengths.k1};
    }

    return result;
}
// Build the C tensor block transfer parameters for ALGORITHM.
// The result starts from the built-in defaults below. When the algorithm
// type satisfies HasCBlockTransferInfo, thread_cluster_lengths is replaced
// by the algorithm's (m_block, m_wave_per_xdl, n_block, n_wave_per_xdl)
// values; every other field keeps its default.
template <ConvAlgorithm auto ALGORITHM>
constexpr CBlockTransfer SetCBlockTransfer()
{
    CBlockTransfer result{
        .m_xdl_per_wave_per_shuffle = 1,
        .n_xdl_per_wave_per_shuffle = 1,
        .thread_cluster_lengths     = {1, 32, 1, 8},
        .scaler_per_vector          = 8,
    };

    if constexpr(HasCBlockTransferInfo<decltype(ALGORITHM)>)
    {
        // Override only the cluster lengths with the algorithm-supplied values.
        constexpr auto& lengths = ALGORITHM.block_transfer.thread_cluster_lengths_c;
        result.thread_cluster_lengths = {
            lengths.m_block, lengths.m_wave_per_xdl, lengths.n_block, lengths.n_wave_per_xdl};
    }

    return result;
}
// Factory builds an instance of a grouped convolution kernel.
template <ConvSignature Signature, ConvAlgorithm auto ALGORITHM, auto Version>
requires SupportedVersion<Version>
@@ -163,34 +230,13 @@ struct GroupedConvForwardXldCShuffleFactoryV3
.conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
.gemm_spec = ck::tensor_operation::device::GemmSpecialization::MNKPadding,
};
static constexpr ConvBlock BLOCK = SetThreadBlockInfo<ALGORITHM>();
static constexpr ConvTuning TUNING = SetConvTuningInfo<ALGORITHM>();
static constexpr BlockTransfer A_BLOCK_TRANSFER{
.thread_cluster_lengths = {4, 64, 1},
.thread_cluster_order = {1, 0, 2},
.src_access_order = {1, 0, 2},
.src_vector_dim = 2,
.src_scaler_per_vector = 8,
.dest_scaler_per_vector_k1 = 8,
.add_extra = 0,
};
static constexpr BlockTransfer B_BLOCK_TRANSFER{
.thread_cluster_lengths = {4, 64, 1},
.thread_cluster_order = {1, 0, 2},
.src_access_order = {1, 0, 2},
.src_vector_dim = 2,
.src_scaler_per_vector = 8,
.dest_scaler_per_vector_k1 = 8,
.add_extra = 0,
};
static constexpr CBlockTransfer C_BLOCK_TRANSFER{
.m_xdl_per_wave_per_shuffle = 1,
.n_xdl_per_wave_per_shuffle = 1,
.cluster_lengths = {1, 32, 1, 8},
.scaler_per_vector = 8,
};
static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave;
static constexpr auto PIPELINE_VERSION = ck::BlockGemmPipelineVersion::v4;
static constexpr ConvBlock BLOCK = SetThreadBlockInfo<ALGORITHM>();
static constexpr ConvTuning TUNING = SetConvTuningInfo<ALGORITHM>();
static constexpr BlockTransfer A_BLOCK_TRANSFER = SetABlockTransfer<ALGORITHM>();
static constexpr BlockTransfer B_BLOCK_TRANSFER = SetBBlockTransfer<ALGORITHM>();
static constexpr CBlockTransfer C_BLOCK_TRANSFER = SetCBlockTransfer<ALGORITHM>();
static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave;
static constexpr auto PIPELINE_VERSION = ck::BlockGemmPipelineVersion::v4;
// The convolution kernel class instance.
using Instance =
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< //
@@ -236,7 +282,7 @@ struct GroupedConvForwardXldCShuffleFactoryV3
B_BLOCK_TRANSFER.add_extra,
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
ToSequence<C_BLOCK_TRANSFER.cluster_lengths>,
ToSequence<C_BLOCK_TRANSFER.thread_cluster_lengths>,
C_BLOCK_TRANSFER.scaler_per_vector,
PIPELINE_SCHEDULER,
PIPELINE_VERSION>;

View File

@@ -36,16 +36,31 @@ struct FwdConvAlgorithm
{
ckb::ThreadBlock thread_block;
ckb::ConvTuningParams tuning_params;
// Thread cluster lengths for the A, B and C block transfers, exposed through
// the `block_transfer` member.
struct BlockTransfer
{
ckb::BlockATransferLengthsInfo thread_cluster_lengths_a;
ckb::BlockBTransferLengthsInfo thread_cluster_lengths_b;
ckb::BlockCTransferLengthsInfo thread_cluster_lengths_c;
} block_transfer;
};
// Compile-time checks that FwdConvAlgorithm satisfies each builder concept
// used in these tests.
static_assert(ckb::ConvAlgorithm<FwdConvAlgorithm>);
static_assert(ckb::HasThreadBlockInfo<FwdConvAlgorithm>);
static_assert(ckb::HasConvTuningInfo<FwdConvAlgorithm>);
static_assert(ckb::HasABlockTransferInfo<FwdConvAlgorithm>);
static_assert(ckb::HasBBlockTransferInfo<FwdConvAlgorithm>);
static_assert(ckb::HasCBlockTransferInfo<FwdConvAlgorithm>);
TEST(ConvBuilderTest, TestConvFwdXdlBf16CompInstances2xInstance0)
{
static constexpr FwdConvAlgorithm algorithm{
.thread_block{.block_size = 256, .sub_matrix = {.m = 256, .n = 128, .k = 64}},
.tuning_params{.ak1 = 16, .bk1 = 16, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}};
.tuning_params{.ak1 = 16, .bk1 = 16, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2},
.block_transfer{
.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
.thread_cluster_lengths_c =
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
}};
using Builder = ckb::ConvBuilder<FwdConvSignature, algorithm, API_VERSION>;
EXPECT_EQ(
Builder::Instance::TypeString(),
@@ -53,12 +68,28 @@ TEST(ConvBuilderTest, TestConvFwdXdlBf16CompInstances2xInstance0)
"8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>");
EXPECT_EQ(Builder::factory::TUNING.ak1, 16);
EXPECT_EQ(Builder::factory::TUNING.bk1, 16);
EXPECT_EQ(Builder::factory::A_BLOCK_TRANSFER.thread_cluster_lengths[0], 4);
EXPECT_EQ(Builder::factory::A_BLOCK_TRANSFER.thread_cluster_lengths[1], 64);
EXPECT_EQ(Builder::factory::A_BLOCK_TRANSFER.thread_cluster_lengths[2], 1);
EXPECT_EQ(Builder::factory::B_BLOCK_TRANSFER.thread_cluster_lengths[0], 4);
EXPECT_EQ(Builder::factory::B_BLOCK_TRANSFER.thread_cluster_lengths[1], 64);
EXPECT_EQ(Builder::factory::B_BLOCK_TRANSFER.thread_cluster_lengths[2], 1);
EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[0], 1);
EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[1], 32);
EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[2], 1);
EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[3], 8);
}
TEST(ConvBuilderTest, GroupedConvFwdXdlBf16CompInstance0)
{
static constexpr FwdConvAlgorithm algorithm{
.thread_block{.block_size = 256, .sub_matrix = {.m = 256, .n = 256, .k = 32}},
.tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}};
.tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4},
.block_transfer{
.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
.thread_cluster_lengths_c =
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
}};
using Builder = ckb::ConvBuilder<FwdConvSignature, algorithm, API_VERSION>;
EXPECT_EQ(
Builder::Instance::TypeString(),
@@ -66,6 +97,16 @@ TEST(ConvBuilderTest, GroupedConvFwdXdlBf16CompInstance0)
"8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>");
EXPECT_EQ(Builder::factory::TUNING.ak1, 8);
EXPECT_EQ(Builder::factory::TUNING.bk1, 8);
EXPECT_EQ(Builder::factory::A_BLOCK_TRANSFER.thread_cluster_lengths[0], 4);
EXPECT_EQ(Builder::factory::A_BLOCK_TRANSFER.thread_cluster_lengths[1], 64);
EXPECT_EQ(Builder::factory::A_BLOCK_TRANSFER.thread_cluster_lengths[2], 1);
EXPECT_EQ(Builder::factory::B_BLOCK_TRANSFER.thread_cluster_lengths[0], 4);
EXPECT_EQ(Builder::factory::B_BLOCK_TRANSFER.thread_cluster_lengths[1], 64);
EXPECT_EQ(Builder::factory::B_BLOCK_TRANSFER.thread_cluster_lengths[2], 1);
EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[0], 1);
EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[1], 32);
EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[2], 1);
EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[3], 8);
};
} // namespace