[CK_BUILDER] Add conv factories for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle and DeviceGroupedConvFwdMultipleD_Wmma_CShuffle (#3138)

* Add device operation to conv signature. Use unions to hold conv layouts and device operations.

* Add predicates for all device op instances.

* Use the device op signature for validation.

* Fix ckb CMakeLists.txt file for tests.

* Fix building CK Builder instance traits after the introduction of direct load template parameter in CK.

* Fix clang-formatting.

* Add factory for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle device op.

* Add conv factory for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle.

* Rename elements per wave per shuffle member in the epilogue concept.

* clang-format

* Add concepts and types for optional device op template parameters.

* Add optional compute, direct load, and loop scheduler arguments to conv factory.

* Add number of groups to merge template parameter.

* clang-format.
This commit is contained in:
Ville Pietilä
2025-11-03 09:03:25 +02:00
committed by GitHub
parent 16e85cf179
commit 3ae3992c18
16 changed files with 986 additions and 168 deletions

View File

@@ -28,8 +28,8 @@ struct ThreadBlock
};
static_assert(ckb::ThreadBlockDescriptor<ThreadBlock>);
// Describe gridwise GEMM parameters.
struct GridwiseGemm
// Describe gridwise XDL GEMM parameters.
struct GridwiseXdlGemm
{
// NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!!
size_t ak1 = 0;
@@ -39,7 +39,26 @@ struct GridwiseGemm
size_t m_xdl_per_wave = 0;
size_t n_xdl_per_wave = 0;
};
static_assert(ckb::GridwiseGemmDescriptor<GridwiseGemm>);
static_assert(ckb::GridwiseXdlGemmDescriptor<GridwiseXdlGemm>);
// Describe gridwise WMMA GEMM parameters.
// Every size_t field defaults to 0; pipeline_version has no default member
// initializer and must be set by whoever constructs the descriptor.
// NOTE(review): the per-field meanings are only implied by the field names;
// confirm against the WMMA device operation's template parameters.
struct GridwiseWmmaGemm
{
size_t k1 = 0;
size_t m_per_wmma = 0;
size_t n_per_wmma = 0;
size_t m_wmma_per_wave = 0;
size_t n_wmma_per_wave = 0;
// Gridwise GEMM pipeline version (no default value).
GridwiseGemmPipelineVersion pipeline_version;
};
// Compile-time check that the struct satisfies ckb::GridwiseWmmaGemmDescriptor.
static_assert(ckb::GridwiseWmmaGemmDescriptor<GridwiseWmmaGemm>);
// Describe block GEMM settings: the pipeline version and the pipeline scheduler.
// Neither member has a default member initializer.
struct BlockGemm
{
BlockGemmPipelineVersion pipeline_version;
BlockGemmPipelineScheduler scheduler;
};
// Compile-time check that the struct satisfies ckb::BlockGemmDescriptor.
static_assert(ckb::BlockGemmDescriptor<BlockGemm>);
// Describe A and B block transfer thread cluster lengths.
struct BlockTransfer
@@ -72,8 +91,8 @@ static_assert(LdsTransferDescriptor<LdsTransfer>);
struct Epilogue
{
size_t m_xdl_per_wave_per_shuffle;
size_t n_xdl_per_wave_per_shuffle;
size_t m_per_wave_per_shuffle;
size_t n_per_wave_per_shuffle;
size_t scalar_per_vector;
};
static_assert(EpilogueDescriptor<Epilogue>);
@@ -98,22 +117,101 @@ struct BlockTransferABC
AccessOrder src_access_order_b;
};
struct ConvAlgorithm
struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
{
ThreadBlock thread_block;
GridwiseGemm gridwise_gemm;
GridwiseXdlGemm gridwise_gemm;
BlockTransferABC block_transfer;
BlockGemmPipelineVersion pipeline_version;
ConvFwdSpecialization fwd_specialization;
GemmSpecialization gemm_specialization;
BlockGemm block_gemm;
};
static_assert(ckb::ConvAlgorithmDescriptor<ConvAlgorithm>);
static_assert(ckb::SpecifiesThreadBlock<ConvAlgorithm>);
static_assert(ckb::SpecifiesGridwiseGemm<ConvAlgorithm>);
static_assert(ckb::SpecifiesBlockTransfer<ConvAlgorithm>);
static_assert(ckb::SpecifiesLdsTransfer<ConvAlgorithm>);
static_assert(ckb::SpecifiesThreadClusterAccessOrder<ConvAlgorithm>);
static_assert(ckb::SpecifiesSourceAccessOrder<ConvAlgorithm>);
static_assert(ckb::SpecifiesGemmPipelineVersion<ConvAlgorithm>);
static_assert(ckb::SpecifiesFwdConcSpecialization<ConvAlgorithm>);
// Compile-time checks that ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
// satisfies the ckb algorithm descriptor and each of the Specifies* predicates
// listed below (thread block, XDL gridwise GEMM, block/LDS transfer, access
// orders, forward specialization, block GEMM, and GEMM specialization).
static_assert(
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
static_assert(
ckb::SpecifiesThreadBlock<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
static_assert(
ckb::SpecifiesGridwiseXdlGemm<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
static_assert(
ckb::SpecifiesBlockTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
static_assert(
ckb::SpecifiesLdsTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
static_assert(ckb::SpecifiesThreadClusterAccessOrder<
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
static_assert(
ckb::SpecifiesSourceAccessOrder<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
static_assert(ckb::SpecifiesFwdConcSpecialization<
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
static_assert(
ckb::SpecifiesBlockGemm<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
static_assert(ckb::SpecifiesGemmSpecialization<
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
// Test algorithm descriptor named after the
// DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle device operation.
// It uses the XDL gridwise GEMM description and additionally carries the
// number of GEMM K prefetch stages, the number of groups to merge, and the
// loop scheduler. No member has a default member initializer.
struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
{
ThreadBlock thread_block;
GridwiseXdlGemm gridwise_gemm;
BlockTransferABC block_transfer;
ConvFwdSpecialization fwd_specialization;
GemmSpecialization gemm_specialization;
size_t num_gemm_k_prefetch_stages;
size_t num_groups_to_merge;
LoopScheduler loop_scheduler;
};
// Compile-time checks that the descriptor satisfies the ckb algorithm
// descriptor and each Specifies* predicate listed below.
static_assert(
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(
ckb::SpecifiesThreadBlock<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(
ckb::SpecifiesGridwiseXdlGemm<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(
ckb::SpecifiesBlockTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(
ckb::SpecifiesLdsTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(ckb::SpecifiesThreadClusterAccessOrder<
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(
ckb::SpecifiesSourceAccessOrder<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(ckb::SpecifiesFwdConcSpecialization<
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(
ckb::SpecifiesNumPrefetchStages<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(
ckb::SpecifiesGemmSpecialization<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(
ckb::SpecifiesLoopScheduler<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
static_assert(
ckb::SpecifiesNumGroupsToMerge<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
// Test algorithm descriptor named after the
// DeviceGroupedConvFwdMultipleD_Wmma_CShuffle device operation.
// It uses the WMMA gridwise GEMM description (GridwiseWmmaGemm) and carries
// the number of GEMM K prefetch stages and the loop scheduler; it has no
// num_groups_to_merge member. No member has a default member initializer.
struct ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
{
ThreadBlock thread_block;
GridwiseWmmaGemm gridwise_gemm;
BlockTransferABC block_transfer;
ConvFwdSpecialization fwd_specialization;
GemmSpecialization gemm_specialization;
size_t num_gemm_k_prefetch_stages;
LoopScheduler loop_scheduler;
};
// Compile-time checks that the descriptor satisfies the ckb algorithm
// descriptor and each Specifies* predicate listed below.
static_assert(
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
static_assert(ckb::SpecifiesThreadBlock<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
static_assert(
ckb::SpecifiesGridwiseWmmaGemm<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
static_assert(
ckb::SpecifiesBlockTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
static_assert(ckb::SpecifiesLdsTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
static_assert(ckb::SpecifiesThreadClusterAccessOrder<
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
static_assert(
ckb::SpecifiesSourceAccessOrder<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
static_assert(
ckb::SpecifiesFwdConcSpecialization<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
static_assert(
ckb::SpecifiesNumPrefetchStages<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
static_assert(
ckb::SpecifiesGemmSpecialization<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
static_assert(
ckb::SpecifiesLoopScheduler<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
} // namespace ck_tile::builder::test