mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 19:57:40 +00:00
Add block transfer parameters to builder.
These are very hard to test in the kernel class, so just test the values in the factory.
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
|
||||
#include <type_traits>
|
||||
#include <concepts>
|
||||
#include <array>
|
||||
|
||||
namespace ck_tile::builder {
|
||||
|
||||
@@ -31,7 +32,6 @@ struct ThreadBlock
|
||||
// Size of the submatrix problem in a thread block.
|
||||
MNK<int> sub_matrix;
|
||||
};
|
||||
|
||||
static_assert(ThreadBlockInfo<ThreadBlock>);
|
||||
|
||||
// Concept to check if struct provides thread block info.
|
||||
@@ -58,7 +58,6 @@ struct ConvTuningParams
|
||||
int m_xdl_per_wave = 0;
|
||||
int n_xdl_per_wave = 0;
|
||||
};
|
||||
|
||||
static_assert(ConvTuningInfo<ConvTuningParams>);
|
||||
|
||||
// Concept to check if a struct provides convolution tuning info.
|
||||
@@ -67,6 +66,77 @@ concept HasConvTuningInfo = requires {
|
||||
{ T::tuning_params } -> ConvTuningInfo;
|
||||
};
|
||||
|
||||
// Concept for A block transfer thread cluster lengths.
// Satisfied when, for an object `t` of type T, the member accesses `t.k0`, `t.m` and
// `t.k1` are well-formed and each yields a type convertible to int.
|
||||
template <typename T>
|
||||
concept BlockATransferLengths = requires(T t) {
|
||||
{ t.k0 } -> std::convertible_to<int>;
|
||||
{ t.m } -> std::convertible_to<int>;
|
||||
{ t.k1 } -> std::convertible_to<int>;
|
||||
};
|
||||
|
||||
// Describe A block transfer thread cluster lengths.
// Plain aggregate with the three members (k0, m, k1) required by BlockATransferLengths.
// Every member defaults to 0, matching the other descriptor structs in this header, so a
// default-initialized instance never holds indeterminate values.
struct BlockATransferLengthsInfo
{
    int k0 = 0;
    int m  = 0;
    int k1 = 0;
};
|
||||
static_assert(BlockATransferLengths<BlockATransferLengthsInfo>);
|
||||
|
||||
// Concept for B block transfer thread cluster lengths.
// Satisfied when, for an object `t` of type T, the member accesses `t.k0`, `t.n` and
// `t.k1` are well-formed and each yields a type convertible to int.
|
||||
template <typename T>
|
||||
concept BlockBTransferLengths = requires(T t) {
|
||||
{ t.k0 } -> std::convertible_to<int>;
|
||||
{ t.n } -> std::convertible_to<int>;
|
||||
{ t.k1 } -> std::convertible_to<int>;
|
||||
};
|
||||
|
||||
// Describe B block transfer thread cluster lengths.
// Plain aggregate with the three members (k0, n, k1) required by BlockBTransferLengths.
// Every member defaults to 0, matching the other descriptor structs in this header, so a
// default-initialized instance never holds indeterminate values.
struct BlockBTransferLengthsInfo
{
    int k0 = 0;
    int n  = 0;
    int k1 = 0;
};
|
||||
static_assert(BlockBTransferLengths<BlockBTransferLengthsInfo>);
|
||||
|
||||
// Concept for C block transfer thread cluster lengths.
// Satisfied when, for an object `t` of type T, the member accesses `t.m_block`,
// `t.m_wave_per_xdl`, `t.n_block` and `t.n_wave_per_xdl` are well-formed and each yields a
// type convertible to int.
|
||||
template <typename T>
|
||||
concept BlockCTransferLengths = requires(T t) {
|
||||
{ t.m_block } -> std::convertible_to<int>;
|
||||
{ t.m_wave_per_xdl } -> std::convertible_to<int>;
|
||||
{ t.n_block } -> std::convertible_to<int>;
|
||||
{ t.n_wave_per_xdl } -> std::convertible_to<int>;
|
||||
};
|
||||
|
||||
// Describe C block transfer thread cluster lengths.
// Plain aggregate with the four members (m_block, m_wave_per_xdl, n_block,
// n_wave_per_xdl) required by BlockCTransferLengths. Every member defaults to 0, matching
// the other descriptor structs in this header, so a default-initialized instance never
// holds indeterminate values.
struct BlockCTransferLengthsInfo
{
    int m_block        = 0;
    int m_wave_per_xdl = 0;
    int n_block        = 0;
    int n_wave_per_xdl = 0;
};
|
||||
static_assert(BlockBTransferLengths<BlockBTransferLengthsInfo>);
|
||||
|
||||
// Concept to check if a struct provides A Block tranfer info.
|
||||
template <typename T>
|
||||
concept HasABlockTransferInfo = requires(T t) {
|
||||
{ T::block_transfer.thread_cluster_lengths_a } -> BlockATransferLengths;
|
||||
};
|
||||
|
||||
// Concept to check if a struct provides B Block tranfer info.
|
||||
template <typename T>
|
||||
concept HasBBlockTransferInfo = requires(T t) {
|
||||
{ T::block_transfer.thread_cluster_lengths_b } -> BlockBTransferLengths;
|
||||
};
|
||||
|
||||
// Concept to check if a struct provides C Block tranfer info.
|
||||
template <typename T>
|
||||
concept HasCBlockTransferInfo = requires(T t) {
|
||||
{ T::block_transfer.thread_cluster_lengths_c } -> BlockCTransferLengths;
|
||||
};
|
||||
|
||||
// No requirements yet for a ConvAlgorithm concept.
|
||||
template <typename T>
|
||||
concept ConvAlgorithm = std::is_class_v<T>;
|
||||
|
||||
@@ -129,7 +129,7 @@ constexpr ConvTuning SetConvTuningInfo()
|
||||
};
|
||||
}
|
||||
|
||||
// Block tranfser paramters for A or B tensor.
|
||||
// Block transfer parameters for A or B tensor.
|
||||
struct BlockTransfer
|
||||
{
|
||||
ck::Array<int, 3> thread_cluster_lengths = {0, 0, 0}; // k0, m, k1
|
||||
@@ -144,12 +144,79 @@ struct BlockTransfer
|
||||
// Block transfer parameters for C tensor.
|
||||
struct CBlockTransfer
|
||||
{
|
||||
int m_xdl_per_wave_per_shuffle = 0;
|
||||
int n_xdl_per_wave_per_shuffle = 0;
|
||||
ck::Array<int, 4> cluster_lengths = {0, 0, 0, 0};
|
||||
int scaler_per_vector = 8;
|
||||
int m_xdl_per_wave_per_shuffle = 0;
|
||||
int n_xdl_per_wave_per_shuffle = 0;
|
||||
ck::Array<int, 4> thread_cluster_lengths = {0, 0, 0, 0};
|
||||
int scaler_per_vector = 8;
|
||||
};
|
||||
|
||||
template <ConvAlgorithm auto ALGORITHM>
|
||||
constexpr BlockTransfer SetABlockTransfer()
|
||||
{
|
||||
BlockTransfer block_transfer{
|
||||
.thread_cluster_lengths = {4, 64, 1},
|
||||
.thread_cluster_order = {1, 0, 2},
|
||||
.src_access_order = {1, 0, 2},
|
||||
.src_vector_dim = 2,
|
||||
.src_scaler_per_vector = 8,
|
||||
.dest_scaler_per_vector_k1 = 8,
|
||||
.add_extra = 0,
|
||||
};
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
if constexpr(HasABlockTransferInfo<AlgorithmType>)
|
||||
{
|
||||
constexpr auto& TCL = ALGORITHM.block_transfer.thread_cluster_lengths_a;
|
||||
block_transfer.thread_cluster_lengths = {TCL.k0, TCL.m, TCL.k1};
|
||||
}
|
||||
// Default.
|
||||
return block_transfer;
|
||||
}
|
||||
|
||||
template <ConvAlgorithm auto ALGORITHM>
|
||||
constexpr BlockTransfer SetBBlockTransfer()
|
||||
{
|
||||
BlockTransfer block_transfer{
|
||||
.thread_cluster_lengths = {4, 64, 1},
|
||||
.thread_cluster_order = {1, 0, 2},
|
||||
.src_access_order = {1, 0, 2},
|
||||
.src_vector_dim = 2,
|
||||
.src_scaler_per_vector = 8,
|
||||
.dest_scaler_per_vector_k1 = 8,
|
||||
.add_extra = 0,
|
||||
};
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
if constexpr(HasBBlockTransferInfo<AlgorithmType>)
|
||||
{
|
||||
constexpr auto& TCL = ALGORITHM.block_transfer.thread_cluster_lengths_b;
|
||||
block_transfer.thread_cluster_lengths = {TCL.k0, TCL.n, TCL.k1};
|
||||
}
|
||||
// Default.
|
||||
return block_transfer;
|
||||
}
|
||||
|
||||
template <ConvAlgorithm auto ALGORITHM>
|
||||
constexpr CBlockTransfer SetCBlockTransfer()
|
||||
{
|
||||
CBlockTransfer block_transfer{
|
||||
.m_xdl_per_wave_per_shuffle = 1,
|
||||
.n_xdl_per_wave_per_shuffle = 1,
|
||||
.thread_cluster_lengths = {1, 32, 1, 8},
|
||||
.scaler_per_vector = 8,
|
||||
};
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
if constexpr(HasCBlockTransferInfo<AlgorithmType>)
|
||||
{
|
||||
constexpr auto& TCL = ALGORITHM.block_transfer.thread_cluster_lengths_c;
|
||||
block_transfer.thread_cluster_lengths = {
|
||||
TCL.m_block,
|
||||
TCL.m_wave_per_xdl,
|
||||
TCL.n_block,
|
||||
TCL.n_wave_per_xdl,
|
||||
};
|
||||
}
|
||||
return block_transfer;
|
||||
}
|
||||
|
||||
// Factory builds an instance of a grouped convolution kernel.
|
||||
template <ConvSignature Signature, ConvAlgorithm auto ALGORITHM, auto Version>
|
||||
requires SupportedVersion<Version>
|
||||
@@ -163,34 +230,13 @@ struct GroupedConvForwardXldCShuffleFactoryV3
|
||||
.conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
|
||||
.gemm_spec = ck::tensor_operation::device::GemmSpecialization::MNKPadding,
|
||||
};
|
||||
static constexpr ConvBlock BLOCK = SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr ConvTuning TUNING = SetConvTuningInfo<ALGORITHM>();
|
||||
static constexpr BlockTransfer A_BLOCK_TRANSFER{
|
||||
.thread_cluster_lengths = {4, 64, 1},
|
||||
.thread_cluster_order = {1, 0, 2},
|
||||
.src_access_order = {1, 0, 2},
|
||||
.src_vector_dim = 2,
|
||||
.src_scaler_per_vector = 8,
|
||||
.dest_scaler_per_vector_k1 = 8,
|
||||
.add_extra = 0,
|
||||
};
|
||||
static constexpr BlockTransfer B_BLOCK_TRANSFER{
|
||||
.thread_cluster_lengths = {4, 64, 1},
|
||||
.thread_cluster_order = {1, 0, 2},
|
||||
.src_access_order = {1, 0, 2},
|
||||
.src_vector_dim = 2,
|
||||
.src_scaler_per_vector = 8,
|
||||
.dest_scaler_per_vector_k1 = 8,
|
||||
.add_extra = 0,
|
||||
};
|
||||
static constexpr CBlockTransfer C_BLOCK_TRANSFER{
|
||||
.m_xdl_per_wave_per_shuffle = 1,
|
||||
.n_xdl_per_wave_per_shuffle = 1,
|
||||
.cluster_lengths = {1, 32, 1, 8},
|
||||
.scaler_per_vector = 8,
|
||||
};
|
||||
static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto PIPELINE_VERSION = ck::BlockGemmPipelineVersion::v4;
|
||||
static constexpr ConvBlock BLOCK = SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr ConvTuning TUNING = SetConvTuningInfo<ALGORITHM>();
|
||||
static constexpr BlockTransfer A_BLOCK_TRANSFER = SetABlockTransfer<ALGORITHM>();
|
||||
static constexpr BlockTransfer B_BLOCK_TRANSFER = SetBBlockTransfer<ALGORITHM>();
|
||||
static constexpr CBlockTransfer C_BLOCK_TRANSFER = SetCBlockTransfer<ALGORITHM>();
|
||||
static constexpr auto PIPELINE_SCHEDULER = ck::BlockGemmPipelineScheduler::Intrawave;
|
||||
static constexpr auto PIPELINE_VERSION = ck::BlockGemmPipelineVersion::v4;
|
||||
// The convolution kernel class instance.
|
||||
using Instance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3< //
|
||||
@@ -236,7 +282,7 @@ struct GroupedConvForwardXldCShuffleFactoryV3
|
||||
B_BLOCK_TRANSFER.add_extra,
|
||||
C_BLOCK_TRANSFER.m_xdl_per_wave_per_shuffle,
|
||||
C_BLOCK_TRANSFER.n_xdl_per_wave_per_shuffle,
|
||||
ToSequence<C_BLOCK_TRANSFER.cluster_lengths>,
|
||||
ToSequence<C_BLOCK_TRANSFER.thread_cluster_lengths>,
|
||||
C_BLOCK_TRANSFER.scaler_per_vector,
|
||||
PIPELINE_SCHEDULER,
|
||||
PIPELINE_VERSION>;
|
||||
|
||||
@@ -36,16 +36,31 @@ struct FwdConvAlgorithm
|
||||
{
|
||||
ckb::ThreadBlock thread_block;
|
||||
ckb::ConvTuningParams tuning_params;
|
||||
struct BlockTransfer
|
||||
{
|
||||
ckb::BlockATransferLengthsInfo thread_cluster_lengths_a;
|
||||
ckb::BlockBTransferLengthsInfo thread_cluster_lengths_b;
|
||||
ckb::BlockCTransferLengthsInfo thread_cluster_lengths_c;
|
||||
} block_transfer;
|
||||
};
|
||||
static_assert(ckb::ConvAlgorithm<FwdConvAlgorithm>);
|
||||
static_assert(ckb::HasThreadBlockInfo<FwdConvAlgorithm>);
|
||||
static_assert(ckb::HasConvTuningInfo<FwdConvAlgorithm>);
|
||||
static_assert(ckb::HasABlockTransferInfo<FwdConvAlgorithm>);
|
||||
static_assert(ckb::HasBBlockTransferInfo<FwdConvAlgorithm>);
|
||||
static_assert(ckb::HasCBlockTransferInfo<FwdConvAlgorithm>);
|
||||
|
||||
TEST(ConvBuilderTest, TestConvFwdXdlBf16CompInstances2xInstance0)
|
||||
{
|
||||
static constexpr FwdConvAlgorithm algorithm{
|
||||
.thread_block{.block_size = 256, .sub_matrix = {.m = 256, .n = 128, .k = 64}},
|
||||
.tuning_params{.ak1 = 16, .bk1 = 16, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2}};
|
||||
.tuning_params{.ak1 = 16, .bk1 = 16, .m_xdl_per_wave = 2, .n_xdl_per_wave = 2},
|
||||
.block_transfer{
|
||||
.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c =
|
||||
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
}};
|
||||
using Builder = ckb::ConvBuilder<FwdConvSignature, algorithm, API_VERSION>;
|
||||
EXPECT_EQ(
|
||||
Builder::Instance::TypeString(),
|
||||
@@ -53,12 +68,28 @@ TEST(ConvBuilderTest, TestConvFwdXdlBf16CompInstances2xInstance0)
|
||||
"8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>");
|
||||
EXPECT_EQ(Builder::factory::TUNING.ak1, 16);
|
||||
EXPECT_EQ(Builder::factory::TUNING.bk1, 16);
|
||||
EXPECT_EQ(Builder::factory::A_BLOCK_TRANSFER.thread_cluster_lengths[0], 4);
|
||||
EXPECT_EQ(Builder::factory::A_BLOCK_TRANSFER.thread_cluster_lengths[1], 64);
|
||||
EXPECT_EQ(Builder::factory::A_BLOCK_TRANSFER.thread_cluster_lengths[2], 1);
|
||||
EXPECT_EQ(Builder::factory::B_BLOCK_TRANSFER.thread_cluster_lengths[0], 4);
|
||||
EXPECT_EQ(Builder::factory::B_BLOCK_TRANSFER.thread_cluster_lengths[1], 64);
|
||||
EXPECT_EQ(Builder::factory::B_BLOCK_TRANSFER.thread_cluster_lengths[2], 1);
|
||||
EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[0], 1);
|
||||
EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[1], 32);
|
||||
EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[2], 1);
|
||||
EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[3], 8);
|
||||
}
|
||||
TEST(ConvBuilderTest, GroupedConvFwdXdlBf16CompInstance0)
|
||||
{
|
||||
static constexpr FwdConvAlgorithm algorithm{
|
||||
.thread_block{.block_size = 256, .sub_matrix = {.m = 256, .n = 256, .k = 32}},
|
||||
.tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4}};
|
||||
.tuning_params{.ak1 = 8, .bk1 = 8, .m_xdl_per_wave = 4, .n_xdl_per_wave = 4},
|
||||
.block_transfer{
|
||||
.thread_cluster_lengths_a = {.k0 = 4, .m = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_b = {.k0 = 4, .n = 64, .k1 = 1},
|
||||
.thread_cluster_lengths_c =
|
||||
{.m_block = 1, .m_wave_per_xdl = 32, .n_block = 1, .n_wave_per_xdl = 8},
|
||||
}};
|
||||
using Builder = ckb::ConvBuilder<FwdConvSignature, algorithm, API_VERSION>;
|
||||
EXPECT_EQ(
|
||||
Builder::Instance::TypeString(),
|
||||
@@ -66,6 +97,16 @@ TEST(ConvBuilderTest, GroupedConvFwdXdlBf16CompInstance0)
|
||||
"8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>");
|
||||
EXPECT_EQ(Builder::factory::TUNING.ak1, 8);
|
||||
EXPECT_EQ(Builder::factory::TUNING.bk1, 8);
|
||||
EXPECT_EQ(Builder::factory::A_BLOCK_TRANSFER.thread_cluster_lengths[0], 4);
|
||||
EXPECT_EQ(Builder::factory::A_BLOCK_TRANSFER.thread_cluster_lengths[1], 64);
|
||||
EXPECT_EQ(Builder::factory::A_BLOCK_TRANSFER.thread_cluster_lengths[2], 1);
|
||||
EXPECT_EQ(Builder::factory::B_BLOCK_TRANSFER.thread_cluster_lengths[0], 4);
|
||||
EXPECT_EQ(Builder::factory::B_BLOCK_TRANSFER.thread_cluster_lengths[1], 64);
|
||||
EXPECT_EQ(Builder::factory::B_BLOCK_TRANSFER.thread_cluster_lengths[2], 1);
|
||||
EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[0], 1);
|
||||
EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[1], 32);
|
||||
EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[2], 1);
|
||||
EXPECT_EQ(Builder::factory::C_BLOCK_TRANSFER.thread_cluster_lengths[3], 8);
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
Reference in New Issue
Block a user