Add tuning parameters to builder.

Add support for setting ak1, bk1, m_xdl_per_wave, and n_xdl_per_wave.

Note: It's difficult to test ak1 and bk1, since the values are not stored in the class.
This commit is contained in:
John Shumway
2025-09-02 16:32:32 +00:00
parent c0f5f5a20e
commit f8b790dfd1
3 changed files with 99 additions and 21 deletions

View File

@@ -5,15 +5,15 @@
namespace ck_tile::builder {
// Convenience struct for a tuple of m, n, and k values.
// Convenience struct for a tuple of m, n, and k values.
template <typename T>
struct MNK {
struct MNK
{
T m{};
T n{};
T k{};
};
// Concept for thread block info for a GEMM problem.
template <typename T>
concept ThreadBlockInfo = requires(T t) {
@@ -23,9 +23,9 @@ concept ThreadBlockInfo = requires(T t) {
{ t.sub_matrix.k } -> std::convertible_to<int>;
};
// Describe a thread block for a GEMM.
struct ThreadBlock {
struct ThreadBlock
{
// Thread block size.
int block_size;
// Size of the submatrix problem in a thread block.
@@ -40,6 +40,32 @@ concept HasThreadBlockInfo = requires {
{ T::THREAD_BLOCK } -> ThreadBlockInfo;
};
// Concept for tuning parameters for a convolution problem.
template <typename T>
concept ConvTuningInfo = requires(T t) {
{ t.ak1 } -> std::convertible_to<int>;
{ t.bk1 } -> std::convertible_to<int>;
{ t.m_xdl_per_wave } -> std::convertible_to<int>;
{ t.n_xdl_per_wave } -> std::convertible_to<int>;
};
// Describe some convolution tuning parameters.
struct ConvTuningParams
{
// NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!!
int ak1 = 0;
int bk1 = 0;
int m_xdl_per_wave = 0;
int n_xdl_per_wave = 0;
};
static_assert(ConvTuningInfo<ConvTuningParams>);
// Concept to check if a struct provides convolution tuning info.
template <typename T>
concept HasConvTuningInfo = requires {
{ T::TUNING_PARAMS } -> ConvTuningInfo;
};
// No requirements yet for a ConvAlogorithm concept.
template <typename T>

View File

@@ -94,13 +94,39 @@ constexpr ConvBlock SetThreadBlockInfo()
struct ConvTuning
{
int ak1 = 0;
int ak2 = 0;
int bk1 = 0;
int m_per_xdl = 0;
int n_per_dxl = 0;
int m_xdl_per_wave = 0;
int n_xdl_per_wave = 0;
};
template <ConvAlgorithm Algo>
constexpr ConvTuning SetConvTuningInfo()
{
if constexpr(HasConvTuningInfo<Algo>)
{
constexpr auto TI = Algo::TUNING_PARAMS;
return ConvTuning{
.ak1 = TI.ak1,
.bk1 = TI.bk1,
.m_per_xdl = 32,
.n_per_dxl = 32,
.m_xdl_per_wave = TI.m_xdl_per_wave,
.n_xdl_per_wave = TI.n_xdl_per_wave,
};
}
// Default values.
return ConvTuning{
.ak1 = 8,
.bk1 = 8,
.m_per_xdl = 32,
.n_per_dxl = 32,
.m_xdl_per_wave = 4,
.n_xdl_per_wave = 4,
};
}
// Block tranfser paramters for A or B tensor.
struct BlockTransfer
{
@@ -135,15 +161,8 @@ struct GroupedConvForwardXldCShuffleFactoryV3
.conv_spec = ck::tensor_operation::device::ConvolutionForwardSpecialization::Default,
.gemm_spec = ck::tensor_operation::device::GemmSpecialization::MNKPadding,
};
static constexpr ConvBlock BLOCK = SetThreadBlockInfo<Algorithm>();
static constexpr ConvTuning TUNING{
.ak1 = 8,
.ak2 = 8,
.m_per_xdl = 32,
.n_per_dxl = 32,
.m_xdl_per_wave = 4,
.n_xdl_per_wave = 4,
};
static constexpr ConvBlock BLOCK = SetThreadBlockInfo<Algorithm>();
static constexpr ConvTuning TUNING = SetConvTuningInfo<Algorithm>();
static constexpr BlockTransfer A_BLOCK_TRANSFER{
.thread_cluster_lengths = {4, 64, 1},
.thread_cluster_order = {1, 0, 2},
@@ -194,7 +213,7 @@ struct GroupedConvForwardXldCShuffleFactoryV3
BLOCK.per_block.n,
BLOCK.per_block.k,
TUNING.ak1,
TUNING.ak2,
TUNING.bk1,
TUNING.m_per_xdl,
TUNING.n_per_dxl,
TUNING.m_xdl_per_wave,

View File

@@ -17,26 +17,36 @@ static_assert(ckb::ConvSignature<FwdConvSignature>);
struct FwdConvAlgorithm
{
// TODO: Add algorithm info.
};
static_assert(ckb::ConvAlgorithm<FwdConvAlgorithm>);
static constexpr char API_VERSION[] = "0.1.0";
using FwdConvBuilder = ckb::ConvBuilder<FwdConvSignature, FwdConvAlgorithm, API_VERSION>;
TEST(ConvBuilderTest, TestDefaultInstance)
{
using Builder = ckb::ConvBuilder<FwdConvSignature, FwdConvAlgorithm, API_VERSION>;
EXPECT_EQ(
FwdConvBuilder::Instance::TypeString(),
Builder::Instance::TypeString(),
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, "
"8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>");
}
[[maybe_unused]] static constexpr ckb::ThreadBlock THREAD_BLOCK_256_256_256_32{
.block_size = 256,
.sub_matrix = {.m = 256, .n = 256, .k = 32},
};
struct ConvFwdXdlBf16CompInstances2xAlgorithm0
{
static constexpr ckb::ThreadBlock THREAD_BLOCK{
.block_size = 256,
.sub_matrix = {.m = 256, .n = 256, .k = 32},
.sub_matrix = {.m = 256, .n = 128, .k = 64},
};
static constexpr ckb::ConvTuningParams TUNING_PARAMS{
.ak1 = 16,
.bk1 = 16,
.m_xdl_per_wave = 2,
.n_xdl_per_wave = 2,
};
};
@@ -46,8 +56,31 @@ TEST(ConvBuilderTest, TestConvFwdXdlBf16CompInstances2xInstance0)
ckb::ConvBuilder<FwdConvSignature, ConvFwdXdlBf16CompInstances2xAlgorithm0, API_VERSION>;
EXPECT_EQ(
Builder::Instance::TypeString(),
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, "
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 128, 64, Default, 32, 32, 2, 2, "
"8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>");
}
struct ConvFwdXdlBf16CompAlgorithm0
{
static constexpr ckb::ThreadBlock THREAD_BLOCK{
.block_size = 256,
.sub_matrix = {.m = 256, .n = 256, .k = 32},
};
static constexpr ckb::ConvTuningParams TUNING_PARAMS{
.ak1 = 8,
.bk1 = 8,
.m_xdl_per_wave = 4,
.n_xdl_per_wave = 4,
};
};
TEST(ConvBuilderTest, GroupedConvFwdXdlBf16CompInstance0)
{
using Builder = ckb::ConvBuilder<FwdConvSignature, ConvFwdXdlBf16CompAlgorithm0, API_VERSION>;
EXPECT_EQ(
Builder::Instance::TypeString(),
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3<256, 256, 256, 32, Default, 32, 32, 4, 4, "
"8, 8, 8, 1, 1, BlkGemmPipelineScheduler: Intrawave, BlkGemmPipelineVersion: v4>");
};
} // namespace