mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
[CK_BUILDER] Add conv factories for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle and DeviceGroupedConvFwdMultipleD_Wmma_CShuffle (#3138)
* Add device operation to conv signature. Use unions to hold conv layouts and device operations. * Add predicates for all device op instances. * Use the device op signature for validation. * Fix ckb CMakeLists.txt file for tests. * Fix building CK Builder instance traits after the introduction of direct load template parameter in CK. * Fix clang-formatting. * Add factory for DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle device op. * Add conv factory for DeviceGroupedConvFwdMultipleD_Wmma_CShuffle * Rename elements per wave per shuffle member in the epilogue concept. * clang-format * Add concepts and types for optional device op template parameters. * Add optional compute, direct load, and loop scheduler arguments to conv factory. * Add number of groups to merge template parameter. * clang-format.
This commit is contained in:
@@ -28,8 +28,8 @@ struct ThreadBlock
|
||||
};
|
||||
static_assert(ckb::ThreadBlockDescriptor<ThreadBlock>);
|
||||
|
||||
// Describe gridwise GEMM parameters.
|
||||
struct GridwiseGemm
|
||||
// Describe gridwise XDL GEMM parameters.
|
||||
struct GridwiseXdlGemm
|
||||
{
|
||||
// NOTE: ak1 and bk1 are difficult to verify in the kernel instantiation!!!
|
||||
size_t ak1 = 0;
|
||||
@@ -39,7 +39,26 @@ struct GridwiseGemm
|
||||
size_t m_xdl_per_wave = 0;
|
||||
size_t n_xdl_per_wave = 0;
|
||||
};
|
||||
static_assert(ckb::GridwiseGemmDescriptor<GridwiseGemm>);
|
||||
static_assert(ckb::GridwiseXdlGemmDescriptor<GridwiseXdlGemm>);
|
||||
|
||||
// Describe gridwise WMMA GEMM parameters.
|
||||
struct GridwiseWmmaGemm
|
||||
{
|
||||
size_t k1 = 0;
|
||||
size_t m_per_wmma = 0;
|
||||
size_t n_per_wmma = 0;
|
||||
size_t m_wmma_per_wave = 0;
|
||||
size_t n_wmma_per_wave = 0;
|
||||
GridwiseGemmPipelineVersion pipeline_version;
|
||||
};
|
||||
static_assert(ckb::GridwiseWmmaGemmDescriptor<GridwiseWmmaGemm>);
|
||||
|
||||
struct BlockGemm
|
||||
{
|
||||
BlockGemmPipelineVersion pipeline_version;
|
||||
BlockGemmPipelineScheduler scheduler;
|
||||
};
|
||||
static_assert(ckb::BlockGemmDescriptor<BlockGemm>);
|
||||
|
||||
// Describe Aand B block transfer thread cluster lengths.
|
||||
struct BlockTransfer
|
||||
@@ -72,8 +91,8 @@ static_assert(LdsTransferDescriptor<LdsTransfer>);
|
||||
|
||||
struct Epilogue
|
||||
{
|
||||
size_t m_xdl_per_wave_per_shuffle;
|
||||
size_t n_xdl_per_wave_per_shuffle;
|
||||
size_t m_per_wave_per_shuffle;
|
||||
size_t n_per_wave_per_shuffle;
|
||||
size_t scalar_per_vector;
|
||||
};
|
||||
static_assert(EpilogueDescriptor<Epilogue>);
|
||||
@@ -98,22 +117,101 @@ struct BlockTransferABC
|
||||
AccessOrder src_access_order_b;
|
||||
};
|
||||
|
||||
struct ConvAlgorithm
|
||||
struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3
|
||||
{
|
||||
ThreadBlock thread_block;
|
||||
GridwiseGemm gridwise_gemm;
|
||||
GridwiseXdlGemm gridwise_gemm;
|
||||
BlockTransferABC block_transfer;
|
||||
BlockGemmPipelineVersion pipeline_version;
|
||||
ConvFwdSpecialization fwd_specialization;
|
||||
GemmSpecialization gemm_specialization;
|
||||
BlockGemm block_gemm;
|
||||
};
|
||||
static_assert(ckb::ConvAlgorithmDescriptor<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesThreadBlock<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesGridwiseGemm<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesBlockTransfer<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesLdsTransfer<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesThreadClusterAccessOrder<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesSourceAccessOrder<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesGemmPipelineVersion<ConvAlgorithm>);
|
||||
static_assert(ckb::SpecifiesFwdConcSpecialization<ConvAlgorithm>);
|
||||
static_assert(
|
||||
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
|
||||
static_assert(
|
||||
ckb::SpecifiesThreadBlock<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
|
||||
static_assert(
|
||||
ckb::SpecifiesGridwiseXdlGemm<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
|
||||
static_assert(
|
||||
ckb::SpecifiesBlockTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
|
||||
static_assert(
|
||||
ckb::SpecifiesLdsTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
|
||||
static_assert(ckb::SpecifiesThreadClusterAccessOrder<
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
|
||||
static_assert(
|
||||
ckb::SpecifiesSourceAccessOrder<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
|
||||
static_assert(ckb::SpecifiesFwdConcSpecialization<
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
|
||||
static_assert(
|
||||
ckb::SpecifiesBlockGemm<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
|
||||
static_assert(ckb::SpecifiesGemmSpecialization<
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3>);
|
||||
|
||||
struct ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
|
||||
{
|
||||
ThreadBlock thread_block;
|
||||
GridwiseXdlGemm gridwise_gemm;
|
||||
BlockTransferABC block_transfer;
|
||||
ConvFwdSpecialization fwd_specialization;
|
||||
GemmSpecialization gemm_specialization;
|
||||
size_t num_gemm_k_prefetch_stages;
|
||||
size_t num_groups_to_merge;
|
||||
LoopScheduler loop_scheduler;
|
||||
};
|
||||
static_assert(
|
||||
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesThreadBlock<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesGridwiseXdlGemm<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesBlockTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesLdsTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(ckb::SpecifiesThreadClusterAccessOrder<
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesSourceAccessOrder<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(ckb::SpecifiesFwdConcSpecialization<
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesNumPrefetchStages<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesGemmSpecialization<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesLoopScheduler<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesNumGroupsToMerge<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>);
|
||||
|
||||
struct ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle
|
||||
{
|
||||
ThreadBlock thread_block;
|
||||
GridwiseWmmaGemm gridwise_gemm;
|
||||
BlockTransferABC block_transfer;
|
||||
ConvFwdSpecialization fwd_specialization;
|
||||
GemmSpecialization gemm_specialization;
|
||||
size_t num_gemm_k_prefetch_stages;
|
||||
LoopScheduler loop_scheduler;
|
||||
};
|
||||
static_assert(
|
||||
ckb::ConvAlgorithmDescriptor<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
static_assert(ckb::SpecifiesThreadBlock<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesGridwiseWmmaGemm<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesBlockTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
static_assert(ckb::SpecifiesLdsTransfer<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
static_assert(ckb::SpecifiesThreadClusterAccessOrder<
|
||||
ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesSourceAccessOrder<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesFwdConcSpecialization<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesNumPrefetchStages<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesGemmSpecialization<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
static_assert(
|
||||
ckb::SpecifiesLoopScheduler<ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle>);
|
||||
|
||||
} // namespace ck_tile::builder::test
|
||||
|
||||
Reference in New Issue
Block a user