diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp index cdd238f36a..c42d2abf80 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp @@ -32,52 +32,13 @@ constexpr ConvTraits instance_to_conv_traits() .gemm_padding = gemm_spec(), .conv_specialization = conv_spec(), .thread_block_size = InstTraits::kBlockSize, - .tile_dims = {.m = InstTraits::kMPerBlock, - .n = InstTraits::kNPerBlock, - .k = InstTraits::kKPerBlock}, - .a_tile_transfer = - {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1, - .m_or_n = InstTraits::kMPerBlock, - .k1 = InstTraits::kAK1}, - .transfer_params = {.k1 = InstTraits::kAK1, - .thread_cluster_dims = InstTraits::kAThreadClusterLengths, - .thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder, - .src_access_order = InstTraits::kABlockTransferSrcAccessOrder, - .src_vector_dim = InstTraits::kABlockTransferSrcVectorDim, - .src_scalar_per_vector = - InstTraits::kABlockTransferSrcScalarPerVector, - .dst_scalar_per_vector_k1 = - InstTraits::kABlockTransferDstScalarPerVectorK1, - .lds_padding = static_cast(InstTraits::kABlockLdsExtraM)}}, - .b_tile_transfer = - {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1, - .m_or_n = InstTraits::kNPerBlock, - .k1 = InstTraits::kBK1}, - .transfer_params = {.k1 = InstTraits::kBK1, - .thread_cluster_dims = InstTraits::kBThreadClusterLengths, - .thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder, - .src_access_order = InstTraits::kBBlockTransferSrcAccessOrder, - .src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim, - .src_scalar_per_vector = - InstTraits::kBBlockTransferSrcScalarPerVector, - .dst_scalar_per_vector_k1 = - InstTraits::kBBlockTransferDstScalarPerVectorK1, - .lds_padding = static_cast(InstTraits::kBBlockLdsExtraN)}}, - .warp_gemm = {.gemm_m = InstTraits::kMPerXDL, - .gemm_n = InstTraits::kNPerXDL, - .m_iter = InstTraits::kMXdlPerWave, - .n_iter = InstTraits::kNXdlPerWave}, - .c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle = - InstTraits::kCShuffleMXdlPerWavePerShuffle, - .n_gemms_per_shuffle = - InstTraits::kCShuffleNXdlPerWavePerShuffle}, - .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], - InstTraits::kCThreadClusterLengths[1], - InstTraits::kCThreadClusterLengths[2], - InstTraits::kCThreadClusterLengths[3]}, - .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector}, - .pipeline_version = get_pipeline_version(), - .pipeline_scheduler = get_pipeline_scheduler(), + .tile_dims = conv_traits_data_tile(), + .a_tile_transfer = conv_traits_xdl_a_transfer_params(), + .b_tile_transfer = conv_traits_xdl_b_transfer_params(), + .warp_gemm = conv_traits_xdl_warp_gemm_params(), + .c_tile_transfer = conv_traits_xdl_c_tile_transfer(), + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), }; } diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp index 28c43c342f..b3d13fb337 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle_v3.hpp @@ -32,52 +32,13 @@ constexpr ConvTraits instance_to_conv_traits() .gemm_padding = gemm_spec(), .conv_specialization = conv_spec(), .thread_block_size = InstTraits::kBlockSize, - .tile_dims = {.m = InstTraits::kMPerBlock, - .n = InstTraits::kNPerBlock, - .k = InstTraits::kKPerBlock}, - .a_tile_transfer = - {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1, - .m_or_n = InstTraits::kMPerBlock, - .k1 = InstTraits::kAK1}, - .transfer_params = {.k1 = InstTraits::kAK1, - .thread_cluster_dims = InstTraits::kAThreadClusterLengths, - .thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder, - .src_access_order = InstTraits::kABlockTransferSrcAccessOrder, - .src_vector_dim = InstTraits::kABlockTransferSrcVectorDim, - .src_scalar_per_vector = - InstTraits::kABlockTransferSrcScalarPerVector, - .dst_scalar_per_vector_k1 = - InstTraits::kABlockTransferDstScalarPerVectorK1, - .lds_padding = static_cast(InstTraits::kABlockLdsExtraM)}}, - .b_tile_transfer = - {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1, - .m_or_n = InstTraits::kNPerBlock, - .k1 = InstTraits::kBK1}, - .transfer_params = {.k1 = InstTraits::kBK1, - .thread_cluster_dims = InstTraits::kBThreadClusterLengths, - .thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder, - .src_access_order = InstTraits::kBBlockTransferSrcAccessOrder, - .src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim, - .src_scalar_per_vector = - InstTraits::kBBlockTransferSrcScalarPerVector, - .dst_scalar_per_vector_k1 = - InstTraits::kBBlockTransferDstScalarPerVectorK1, - .lds_padding = static_cast(InstTraits::kBBlockLdsExtraN)}}, - .warp_gemm = {.gemm_m = InstTraits::kMPerXDL, - .gemm_n = InstTraits::kNPerXDL, - .m_iter = InstTraits::kMXdlPerWave, - .n_iter = InstTraits::kNXdlPerWave}, - .c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle = - InstTraits::kCShuffleMXdlPerWavePerShuffle, - .n_gemms_per_shuffle = - InstTraits::kCShuffleNXdlPerWavePerShuffle}, - .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], - InstTraits::kCThreadClusterLengths[1], - InstTraits::kCThreadClusterLengths[2], - InstTraits::kCThreadClusterLengths[3]}, - .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector}, - .pipeline_version = get_pipeline_version(), - .pipeline_scheduler = get_pipeline_scheduler(), + .tile_dims = conv_traits_data_tile(), + .a_tile_transfer = conv_traits_xdl_a_transfer_params(), + .b_tile_transfer = conv_traits_xdl_b_transfer_params(), + .warp_gemm = conv_traits_xdl_warp_gemm_params(), + .c_tile_transfer = conv_traits_xdl_c_tile_transfer(), + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), }; } diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp index c4bed850eb..cf417ad959 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp @@ -32,52 +32,13 @@ constexpr ConvTraits instance_to_conv_traits() .gemm_padding = gemm_spec(), .conv_specialization = conv_spec(), .thread_block_size = InstTraits::kBlockSize, - .tile_dims = {.m = InstTraits::kMPerBlock, - .n = InstTraits::kNPerBlock, - .k = InstTraits::kKPerBlock}, - .a_tile_transfer = - {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1, - .m_or_n = InstTraits::kMPerBlock, - .k1 = InstTraits::kAK1}, - .transfer_params = {.k1 = InstTraits::kAK1, - .thread_cluster_dims = InstTraits::kAThreadClusterLengths, - .thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder, - .src_access_order = InstTraits::kABlockTransferSrcAccessOrder, - .src_vector_dim = InstTraits::kABlockTransferSrcVectorDim, - .src_scalar_per_vector = - InstTraits::kABlockTransferSrcScalarPerVector, - .dst_scalar_per_vector_k1 = - InstTraits::kABlockTransferDstScalarPerVectorK1, - .lds_padding = static_cast(InstTraits::kABlockLdsExtraM)}}, - .b_tile_transfer = - {.tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1, - .m_or_n = InstTraits::kNPerBlock, - .k1 = InstTraits::kBK1}, - .transfer_params = {.k1 = InstTraits::kBK1, - .thread_cluster_dims = InstTraits::kBThreadClusterLengths, - .thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder, - .src_access_order = InstTraits::kBBlockTransferSrcAccessOrder, - .src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim, - .src_scalar_per_vector = - InstTraits::kBBlockTransferSrcScalarPerVector, - .dst_scalar_per_vector_k1 = - InstTraits::kBBlockTransferDstScalarPerVectorK1, - .lds_padding = static_cast(InstTraits::kBBlockLdsExtraN)}}, - .warp_gemm = {.gemm_m = InstTraits::kMPerXDL, - .gemm_n = InstTraits::kNPerXDL, - .m_iter = InstTraits::kMXdlPerWave, - .n_iter = InstTraits::kNXdlPerWave}, - .c_tile_transfer = {.shuffle_params = {.m_gemms_per_shuffle = - InstTraits::kCShuffleMXdlPerWavePerShuffle, - .n_gemms_per_shuffle = - InstTraits::kCShuffleNXdlPerWavePerShuffle}, - .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], - InstTraits::kCThreadClusterLengths[1], - InstTraits::kCThreadClusterLengths[2], - InstTraits::kCThreadClusterLengths[3]}, - .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector}, - .pipeline_version = get_pipeline_version(), - .pipeline_scheduler = get_pipeline_scheduler(), + .tile_dims = conv_traits_data_tile(), + .a_tile_transfer = conv_traits_xdl_a_transfer_params(), + .b_tile_transfer = conv_traits_xdl_b_transfer_params(), + .warp_gemm = conv_traits_xdl_warp_gemm_params(), + .c_tile_transfer = conv_traits_xdl_c_tile_transfer(), + .pipeline_version = get_pipeline_version(), + .pipeline_scheduler = get_pipeline_scheduler(), }; } diff --git a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp index 92885ad789..60d11bce6e 100644 --- a/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp +++ b/experimental/builder/include/ck_tile/builder/reflect/conv_traits_helpers.hpp @@ -640,4 +640,91 @@ constexpr auto get_pipeline_scheduler() } } +template +constexpr InputTileTransferInfo conv_traits_xdl_a_transfer_params() +{ + return InputTileTransferInfo{ + .tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kAK1, + .m_or_n = InstTraits::kMPerBlock, + .k1 = InstTraits::kAK1}, + .transfer_params = {.k1 = InstTraits::kAK1, + .thread_cluster_dims = InstTraits::kAThreadClusterLengths, + .thread_cluster_order = InstTraits::kAThreadClusterArrangeOrder, + .src_access_order = InstTraits::kABlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kABlockTransferSrcVectorDim, + .src_scalar_per_vector = InstTraits::kABlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kABlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kABlockLdsExtraM)}}; +} + +template +constexpr InputTileTransferInfo conv_traits_xdl_b_transfer_params() +{ + return InputTileTransferInfo{ + .tile_dimensions = {.k0 = InstTraits::kKPerBlock / InstTraits::kBK1, + .m_or_n = InstTraits::kNPerBlock, + .k1 = InstTraits::kBK1}, + .transfer_params = {.k1 = InstTraits::kBK1, + .thread_cluster_dims = InstTraits::kBThreadClusterLengths, + .thread_cluster_order = InstTraits::kBThreadClusterArrangeOrder, + .src_access_order = InstTraits::kBBlockTransferSrcAccessOrder, + .src_vector_dim = InstTraits::kBBlockTransferSrcVectorDim, + .src_scalar_per_vector = InstTraits::kBBlockTransferSrcScalarPerVector, + .dst_scalar_per_vector_k1 = + InstTraits::kBBlockTransferDstScalarPerVectorK1, + .lds_padding = static_cast(InstTraits::kBBlockLdsExtraN)}}; +} + +template +constexpr OutputTileTransferInfo conv_traits_xdl_c_tile_transfer() +{ + return OutputTileTransferInfo{ + .shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMXdlPerWavePerShuffle, + .n_gemms_per_shuffle = InstTraits::kCShuffleNXdlPerWavePerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector}; +} + +template +constexpr OutputTileTransferInfo conv_traits_wmma_c_tile_transfer() +{ + return OutputTileTransferInfo{ + .shuffle_params = {.m_gemms_per_shuffle = InstTraits::kCShuffleMRepeatPerShuffle, + .n_gemms_per_shuffle = InstTraits::kCShuffleNRepeatPerShuffle}, + .thread_cluster_dims = {InstTraits::kCThreadClusterLengths[0], + InstTraits::kCThreadClusterLengths[1], + InstTraits::kCThreadClusterLengths[2], + InstTraits::kCThreadClusterLengths[3]}, + .scalar_per_vector = InstTraits::kCBlockTransferScalarPerVector}; +} + +template +constexpr WarpGemmParams conv_traits_xdl_warp_gemm_params() +{ + return WarpGemmParams{.gemm_m = InstTraits::kMPerXDL, + .gemm_n = InstTraits::kNPerXDL, + .m_iter = InstTraits::kMXdlPerWave, + .n_iter = InstTraits::kNXdlPerWave}; +} + +template +constexpr WarpGemmParams conv_traits_wmma_warp_gemm_params() +{ + return WarpGemmParams{.gemm_m = InstTraits::kMPerWmma, + .gemm_n = InstTraits::kNPerWmma, + .m_iter = InstTraits::kMRepeat, + .n_iter = InstTraits::kNRepeat}; +} + +template +constexpr DataTileInfo conv_traits_data_tile() +{ + return DataTileInfo{ + .m = InstTraits::kMPerBlock, .n = InstTraits::kNPerBlock, .k = InstTraits::kKPerBlock}; +} + } // namespace ck_tile::reflect::conv