From f74e034ae9de98cf310c27f8464a15a024d6ff4b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ville=20Pietil=C3=A4?= <> Date: Fri, 9 Jan 2026 09:17:45 -0500 Subject: [PATCH] Adapt factories to warp GEMM and transfer parameters refactoring. --- .../builder/conv_algorithm_concepts.hpp | 20 +++++++------- ...onv_bwd_weight_multi_d_wmma_v3_factory.hpp | 17 +++++++----- .../conv_bwd_weight_multi_d_xdl_factory.hpp | 18 ++++++++----- ...v_bwd_weight_two_stage_wmma_v3_factory.hpp | 17 +++++++----- .../conv_bwd_weight_two_stage_xdl_factory.hpp | 18 ++++++++----- .../factory/conv_bwd_weight_wmma_factory.hpp | 17 +++++++----- .../conv_bwd_weight_wmma_v3_factory.hpp | 17 +++++++----- .../factory/conv_bwd_weight_xdl_factory.hpp | 19 ++++++++----- .../conv_bwd_weight_xdl_v3_factory.hpp | 18 ++++++++----- .../factory/conv_fwd_large_tensor_factory.hpp | 15 +++++------ .../builder/factory/conv_fwd_v3_factory.hpp | 21 +++++++-------- .../builder/factory/conv_fwd_wmma_factory.hpp | 17 +++++++----- .../builder/factory/conv_fwd_xdl_factory.hpp | 15 +++++------ .../helpers/ck/conv_block_transfer.hpp | 27 +++++++++++-------- .../factory/helpers/ck/conv_tuning_params.hpp | 2 +- .../test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp | 2 +- 16 files changed, 152 insertions(+), 108 deletions(-) diff --git a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp index cbc277a881..d036ab7ec6 100644 --- a/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp +++ b/experimental/builder/include/ck_tile/builder/conv_algorithm_concepts.hpp @@ -172,8 +172,8 @@ concept SpecifiesTileThreadBlock = requires { // Concept to check if a struct specifies warp GEMM info. template -concept SpecifiesWarpGemm = requires(T t) { - { t.warp_gemm } -> WarpGemmDescriptor; +concept SpecifiesWarpGemm = requires { + { T::warp_gemm } -> WarpGemmDescriptor; }; // Concept to check if a struct specifies convolution input and output block transfer info. @@ -212,8 +212,8 @@ concept SpecifiesLdsTransfer = requires(T t) { // Concept to check if a struct specifies thread cluster access order info. template concept SpecifiesThreadClusterAccessOrder = requires(T t) { - { T::transfer.a.block_transfer_access_order } -> AccessOrderDescriptor; - { T::transfer.b.block_transfer_access_order } -> AccessOrderDescriptor; + { T::transfer.a.thread_distribution_access_order } -> AccessOrderDescriptor; + { T::transfer.b.thread_distribution_access_order } -> AccessOrderDescriptor; }; // Concept to check if a struct specifies source access order info. @@ -341,15 +341,15 @@ concept SpecifiesMultipleDSupport = requires { }; template -concept SpecifiesXdl = requires { - { T::warp_gemm.matrix_instruction } -> std::convertible_to; - requires T::warp_gemm.matrix_instruction == MatrixInstructionType::XDL; +concept SpecifiesXdl = requires (T t){ + { t.warp_gemm.matrix_instruction } -> std::convertible_to; + { t.warp_gemm.matrix_instruction == MatrixInstructionType::XDL}; }; template -concept SpecifiesWmma = requires { - { T::warp_gemm.matrix_instruction } -> std::convertible_to; - requires T::warp_gemm.matrix_instruction == MatrixInstructionType::WMMA; +concept SpecifiesWmma = requires (T t){ + { t.warp_gemm.matrix_instruction } -> std::convertible_to; + { t.warp_gemm.matrix_instruction == MatrixInstructionType::WMMA}; }; /******************************************** */ diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp index 6f5c679b59..234ba39829 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_wmma_v3_factory.hpp @@ -35,7 +35,7 @@ struct ConvBwdWeightMultiDWmmaV3Factory internal::SetBwdWeightConvSpecialization(); static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto WARP_GEMM = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -57,6 +57,11 @@ struct ConvBwdWeightMultiDWmmaV3Factory static_assert(AccessOrderLimits3D, "Invalid B source access order"); + static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size == + B_BLOCK_TRANSFER.global_memory_vector_load_size, + "A nd B block transfer vector load size need to be the same"); + static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size; + // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3< @@ -78,11 +83,11 @@ struct ConvBwdWeightMultiDWmmaV3Factory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.k1, - GRIDWISE_GEMM.m_per_wmma, - GRIDWISE_GEMM.n_per_wmma, - GRIDWISE_GEMM.m_wmma_per_wave, - GRIDWISE_GEMM.n_wmma_per_wave, + GMEM_VECTOR_LOAD_SIZE, + WARP_GEMM.gemm_m_per_instruction, + WARP_GEMM.gemm_n_per_instruction, + WARP_GEMM.gemm_m_iters_per_wave, + WARP_GEMM.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp index 9f76568ca8..843c4e0f90 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_multi_d_xdl_factory.hpp @@ -35,8 +35,7 @@ struct ConvBwdWeightMultiDXdlFactory internal::SetBwdWeightConvSpecialization(); static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; - static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -53,6 +52,11 @@ struct ConvBwdWeightMultiDXdlFactory static_assert(AccessOrderLimits4D); static_assert(AccessOrderLimits4D); + static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size == + B_BLOCK_TRANSFER.global_memory_vector_load_size, + "A nd B block transfer vector load size need to be the same"); + static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size; + // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle< SPATIAL_DIM, @@ -73,11 +77,11 @@ struct ConvBwdWeightMultiDXdlFactory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.k1, - XDL_PARAMS.m_per_xdl, - XDL_PARAMS.n_per_xdl, - XDL_PARAMS.m_xdl_per_wave, - XDL_PARAMS.n_xdl_per_wave, + GMEM_VECTOR_LOAD_SIZE, + XDL_PARAMS.gemm_m_per_instruction, + XDL_PARAMS.gemm_n_per_instruction, + XDL_PARAMS.gemm_m_iters_per_wave, + XDL_PARAMS.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp index 86c48fe322..48a15a1638 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_wmma_v3_factory.hpp @@ -35,7 +35,7 @@ struct ConvBwdWeightTwoStageWmmaV3Factory internal::SetBwdWeightConvSpecialization(); static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto WARP_GEMM = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -57,6 +57,11 @@ struct ConvBwdWeightTwoStageWmmaV3Factory static_assert(AccessOrderLimits3D, "Invalid B source access order"); + static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size == + B_BLOCK_TRANSFER.global_memory_vector_load_size, + "A nd B block transfer vector load size need to be the same"); + static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size; + // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3< @@ -76,11 +81,11 @@ struct ConvBwdWeightTwoStageWmmaV3Factory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.k1, - GRIDWISE_GEMM.m_per_wmma, - GRIDWISE_GEMM.n_per_wmma, - GRIDWISE_GEMM.m_wmma_per_wave, - GRIDWISE_GEMM.n_wmma_per_wave, + GMEM_VECTOR_LOAD_SIZE, + WARP_GEMM.gemm_m_per_instruction, + WARP_GEMM.gemm_n_per_instruction, + WARP_GEMM.gemm_m_iters_per_wave, + WARP_GEMM.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp index 9c37beae46..5eea36313f 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_two_stage_xdl_factory.hpp @@ -35,8 +35,7 @@ struct ConvBwdWeightTwoStageXdlFactory internal::SetBwdWeightConvSpecialization(); static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; - static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -58,6 +57,11 @@ struct ConvBwdWeightTwoStageXdlFactory static_assert(AccessOrderLimits3D, "Invalid B source access order"); + static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size == + B_BLOCK_TRANSFER.global_memory_vector_load_size, + "A nd B block transfer vector load size need to be the same"); + static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size; + // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle< SPATIAL_DIM, @@ -76,11 +80,11 @@ struct ConvBwdWeightTwoStageXdlFactory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.k1, - XDL_PARAMS.m_per_xdl, - XDL_PARAMS.n_per_xdl, - XDL_PARAMS.m_xdl_per_wave, - XDL_PARAMS.n_xdl_per_wave, + GMEM_VECTOR_LOAD_SIZE, + XDL_PARAMS.gemm_m_per_instruction, + XDL_PARAMS.gemm_n_per_instruction, + XDL_PARAMS.gemm_m_iters_per_wave, + XDL_PARAMS.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp index 32161a234a..8e22958ac1 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_factory.hpp @@ -35,7 +35,7 @@ struct ConvBwdWeightWmmaFactory internal::SetBwdWeightConvSpecialization(); static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto WARP_GEMM = ALGORITHM.warp_gemm; static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION = internal::SetGridwiseGemmPipelineVersion(); static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); @@ -60,6 +60,11 @@ struct ConvBwdWeightWmmaFactory static_assert(AccessOrderLimits3D, "Invalid B source access order"); + static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size == + B_BLOCK_TRANSFER.global_memory_vector_load_size, + "A nd B block transfer vector load size need to be the same"); + static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size; + // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle< SPATIAL_DIM, @@ -78,11 +83,11 @@ struct ConvBwdWeightWmmaFactory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.k1, - GRIDWISE_GEMM.m_per_wmma, - GRIDWISE_GEMM.n_per_wmma, - GRIDWISE_GEMM.m_wmma_per_wave, - GRIDWISE_GEMM.n_wmma_per_wave, + GMEM_VECTOR_LOAD_SIZE, + WARP_GEMM.gemm_m_per_instruction, + WARP_GEMM.gemm_n_per_instruction, + WARP_GEMM.gemm_m_iters_per_wave, + WARP_GEMM.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp index baf84402c3..463749958a 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_wmma_v3_factory.hpp @@ -35,7 +35,7 @@ struct ConvBwdWeightWmmaV3Factory internal::SetBwdWeightConvSpecialization(); static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto WARP_GEMM = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -57,6 +57,11 @@ struct ConvBwdWeightWmmaV3Factory static_assert(AccessOrderLimits3D, "Invalid B source access order"); + static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size == + B_BLOCK_TRANSFER.global_memory_vector_load_size, + "A nd B block transfer vector load size need to be the same"); + static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size; + // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffleV3< SPATIAL_DIM, @@ -75,11 +80,11 @@ struct ConvBwdWeightWmmaV3Factory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.k1, - GRIDWISE_GEMM.m_per_wmma, - GRIDWISE_GEMM.n_per_wmma, - GRIDWISE_GEMM.m_wmma_per_wave, - GRIDWISE_GEMM.n_wmma_per_wave, + GMEM_VECTOR_LOAD_SIZE, + WARP_GEMM.gemm_m_per_instruction, + WARP_GEMM.gemm_n_per_instruction, + WARP_GEMM.gemm_m_iters_per_wave, + WARP_GEMM.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp index 91c19d2bd0..ba5fdb2c53 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_factory.hpp @@ -35,8 +35,7 @@ struct ConvBwdWeightXdlFactory internal::SetBwdWeightConvSpecialization(); static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; - static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -53,6 +52,12 @@ struct ConvBwdWeightXdlFactory static_assert(AccessOrderLimits4D); static_assert(AccessOrderLimits4D); + static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size == + B_BLOCK_TRANSFER.global_memory_vector_load_size, + "A nd B block transfer vector load size need to be the same"); + static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size; + + // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle< SPATIAL_DIM, @@ -71,11 +76,11 @@ struct ConvBwdWeightXdlFactory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.k1, - XDL_PARAMS.m_per_xdl, - XDL_PARAMS.n_per_xdl, - XDL_PARAMS.m_xdl_per_wave, - XDL_PARAMS.n_xdl_per_wave, + GMEM_VECTOR_LOAD_SIZE, + XDL_PARAMS.gemm_m_per_instruction, + XDL_PARAMS.gemm_n_per_instruction, + XDL_PARAMS.gemm_m_iters_per_wave, + XDL_PARAMS.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp index f3edd0e6d9..ab4dbea2f4 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_bwd_weight_xdl_v3_factory.hpp @@ -35,8 +35,7 @@ struct ConvBwdWeightXdlV3Factory internal::SetBwdWeightConvSpecialization(); static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; - static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetBwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -58,6 +57,11 @@ struct ConvBwdWeightXdlV3Factory static_assert(AccessOrderLimits3D, "Invalid B source access order"); + static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size == + B_BLOCK_TRANSFER.global_memory_vector_load_size, + "A nd B block transfer vector load size need to be the same"); + static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size; + // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3< SPATIAL_DIM, @@ -76,11 +80,11 @@ struct ConvBwdWeightXdlV3Factory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.k1, - XDL_PARAMS.m_per_xdl, - XDL_PARAMS.n_per_xdl, - XDL_PARAMS.m_xdl_per_wave, - XDL_PARAMS.n_xdl_per_wave, + GMEM_VECTOR_LOAD_SIZE, + XDL_PARAMS.gemm_m_per_instruction, + XDL_PARAMS.gemm_n_per_instruction, + XDL_PARAMS.gemm_m_iters_per_wave, + XDL_PARAMS.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp index 9cd56ad7ad..fdb95d602a 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_large_tensor_factory.hpp @@ -38,8 +38,7 @@ struct ConvFwdLargeTensorFactory static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; - static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -79,12 +78,12 @@ struct ConvFwdLargeTensorFactory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.ak1, - GRIDWISE_GEMM.bk1, - XDL_PARAMS.m_per_xdl, - XDL_PARAMS.n_per_xdl, - XDL_PARAMS.m_xdl_per_wave, - XDL_PARAMS.n_xdl_per_wave, + A_BLOCK_TRANSFER.global_memory_vector_load_size, + B_BLOCK_TRANSFER.global_memory_vector_load_size, + XDL_PARAMS.gemm_m_per_instruction, + XDL_PARAMS.gemm_n_per_instruction, + XDL_PARAMS.gemm_m_iters_per_wave, + XDL_PARAMS.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp index 7d889a0c01..a64929d158 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_v3_factory.hpp @@ -31,19 +31,18 @@ struct ConvFwdXdlV3Factory using Ops = internal::ElementwiseOps; using AlgorithmType = decltype(ALGORITHM); - static_assert(ALGORITHM.transfer.a.lds_transfer.is_direct_load == - ALGORITHM.transfer.b.lds_transfer.is_direct_load, + static_assert(ALGORITHM.transfer.a.lds_transfer_params.is_direct_load == + ALGORITHM.transfer.b.lds_transfer_params.is_direct_load, "A and B block transfers must both be direct load or not."); - static constexpr bool IS_DIRECT_LOAD = ALGORITHM.transfer.a.lds_transfer.is_direct_load; + static constexpr bool IS_DIRECT_LOAD = ALGORITHM.transfer.a.lds_transfer_params.is_direct_load; static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization(); static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization(); static constexpr internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION, .gemm_spec = GEMM_SPECIALIZATION}; static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; - static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -83,12 +82,12 @@ struct ConvFwdXdlV3Factory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.ak1, - GRIDWISE_GEMM.bk1, - XDL_PARAMS.m_per_xdl, - XDL_PARAMS.n_per_xdl, - XDL_PARAMS.m_xdl_per_wave, - XDL_PARAMS.n_xdl_per_wave, + A_BLOCK_TRANSFER.global_memory_vector_load_size, + B_BLOCK_TRANSFER.global_memory_vector_load_size, + XDL_PARAMS.gemm_m_per_instruction, + XDL_PARAMS.gemm_n_per_instruction, + XDL_PARAMS.gemm_m_iters_per_wave, + XDL_PARAMS.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp index 3506f5d1a9..d52f684d8c 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_wmma_factory.hpp @@ -38,7 +38,7 @@ struct ConvFwdWmmaFactory static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; + static constexpr auto WARP_GEMM = ALGORITHM.warp_gemm; static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION = internal::SetGridwiseGemmPipelineVersion(); static constexpr auto A_BLOCK_TRANSFER = @@ -57,6 +57,11 @@ struct ConvFwdWmmaFactory static_assert(AccessOrderLimits3D); static_assert(AccessOrderLimits3D); + static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size == + B_BLOCK_TRANSFER.global_memory_vector_load_size, + "A nd B block transfer vector load size need to be the same"); + static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size; + // The forward convolution kernel class instance. using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle< SPATIAL_DIM, @@ -80,11 +85,11 @@ struct ConvFwdWmmaFactory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.k1, - GRIDWISE_GEMM.m_per_wmma, - GRIDWISE_GEMM.n_per_wmma, - GRIDWISE_GEMM.m_wmma_per_wave, - GRIDWISE_GEMM.n_wmma_per_wave, + GMEM_VECTOR_LOAD_SIZE, + WARP_GEMM.gemm_m_per_instruction, + WARP_GEMM.gemm_n_per_instruction, + WARP_GEMM.gemm_m_iters_per_wave, + WARP_GEMM.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp index 446ceceda2..eb2fdfad4d 100644 --- a/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/conv_fwd_xdl_factory.hpp @@ -38,8 +38,7 @@ struct ConvFwdXdlFactory static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler(); static constexpr auto BLOCK = internal::SetThreadBlockInfo(); - static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm; - static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params; + static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm; static constexpr auto A_BLOCK_TRANSFER = internal::SetFwdConvBlockTransfer(); static constexpr auto B_BLOCK_TRANSFER = @@ -79,12 +78,12 @@ struct ConvFwdXdlFactory BLOCK.per_block.m, BLOCK.per_block.n, BLOCK.per_block.k, - GRIDWISE_GEMM.ak1, - GRIDWISE_GEMM.bk1, - XDL_PARAMS.m_per_xdl, - XDL_PARAMS.n_per_xdl, - XDL_PARAMS.m_xdl_per_wave, - XDL_PARAMS.n_xdl_per_wave, + A_BLOCK_TRANSFER.global_memory_vector_load_size, + B_BLOCK_TRANSFER.global_memory_vector_load_size, + XDL_PARAMS.gemm_m_per_instruction, + XDL_PARAMS.gemm_n_per_instruction, + XDL_PARAMS.gemm_m_iters_per_wave, + XDL_PARAMS.gemm_n_iters_per_wave, to_sequence_v, to_sequence_v, to_sequence_v, diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp index 4ef2f533c9..dfe35355ad 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_block_transfer.hpp @@ -15,6 +15,7 @@ struct BlockTransfer ck::Array thread_cluster_dims{}; // k0, m, k1 ck::Array thread_cluster_order{}; ck::Array src_access_order{}; + size_t global_memory_vector_load_size = 0; size_t src_vector_dim = 0; size_t src_scalar_per_vector = 0; size_t lds_dst_scalar_per_vector = 0; @@ -28,6 +29,7 @@ struct BwdBlockTransfer ck::Array thread_cluster_dims{}; ck::Array thread_cluster_order{}; ck::Array src_access_order{}; + size_t global_memory_vector_load_size = 0; size_t src_vector_dim = 0; size_t src_scalar_per_vector = 0; size_t lds_dst_scalar_per_vector = 0; @@ -37,15 +39,16 @@ struct BwdBlockTransfer template constexpr BlockTransfer SetFwdConvBlockTransfer() { - auto& block_xfer = TRANSFER.block_transfer; - auto& block_order = TRANSFER.block_transfer_access_order; + auto& block_xfer = TRANSFER.thread_distribution; + auto& block_order = TRANSFER.thread_distribution_access_order; auto& src_order = TRANSFER.src_access_order; - auto& lds_cfg = TRANSFER.lds_transfer; + auto& lds_cfg = TRANSFER.lds_transfer_params; return BlockTransfer{ .thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1}, .thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2]}, .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]}, + .global_memory_vector_load_size = lds_cfg.global_memory_vector_load_size, .src_vector_dim = lds_cfg.src_vector_dim, .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, @@ -57,10 +60,10 @@ constexpr BlockTransfer SetFwdConvBlockTransfer() template constexpr auto SetBwdConvBlockTransfer() { - auto& block_xfer = TRANSFER.block_transfer; - auto& block_order = TRANSFER.block_transfer_access_order; + auto& block_xfer = TRANSFER.thread_distribution; + auto& block_order = TRANSFER.thread_distribution_access_order; auto& src_order = TRANSFER.src_access_order; - auto& lds_cfg = TRANSFER.lds_transfer; + auto& lds_cfg = TRANSFER.lds_transfer_params; constexpr auto array_length = block_order.order.size(); static_assert(block_order.order.size() == src_order.order.size(), @@ -74,6 +77,7 @@ constexpr auto SetBwdConvBlockTransfer() block_order.order[1], block_order.order[2]}, .src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]}, + .global_memory_vector_load_size = lds_cfg.global_memory_vector_load_size, .src_vector_dim = lds_cfg.src_vector_dim, .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, @@ -95,6 +99,7 @@ constexpr auto SetBwdConvBlockTransfer() src_order.order[1], src_order.order[2], src_order.order[3]}, + .global_memory_vector_load_size = lds_cfg.global_memory_vector_load_size, .src_vector_dim = lds_cfg.src_vector_dim, .src_scalar_per_vector = lds_cfg.src_scalar_per_vector, .lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector, @@ -119,17 +124,17 @@ struct CBlockTransfer template constexpr CBlockTransfer SetCBlockTransfer() { - auto& thread_cluster_dims = ALGORITHM.transfer.c.thread_cluster_dims; + auto& thread_cluster_dims = ALGORITHM.transfer.c.thread_distribution; auto& epilogue_config = ALGORITHM.transfer.c.epilogue; return CBlockTransfer{ .m_xdl_per_wave_per_shuffle = epilogue_config.m_xdl_per_wave_per_shuffle, .n_xdl_per_wave_per_shuffle = epilogue_config.n_per_wave_per_shuffle, .thread_cluster_dims = { - thread_cluster_dims.m_block, - thread_cluster_dims.m_wave_per_xdl, - thread_cluster_dims.n_block, - thread_cluster_dims.n_wave_per_xdl, + thread_cluster_dims.gemm_m_block_size, + thread_cluster_dims.gemm_m_per_block, + thread_cluster_dims.gemm_n_block_size, + thread_cluster_dims.gemm_n_per_block, }, .scalar_per_vector = epilogue_config.scalar_per_vector, }; diff --git a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp index 9ed1eebc3c..29cf3f8513 100644 --- a/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp +++ b/experimental/builder/include/ck_tile/builder/factory/helpers/ck/conv_tuning_params.hpp @@ -38,7 +38,7 @@ struct BlockGemmSpec template consteval BlockGemmSpec SetBlockGemm() { - constexpr auto& BG = ALGORITHM.block_gemm_pipeline; + constexpr auto& BG = ALGORITHM.gemm_pipeline; ck::BlockGemmPipelineScheduler scheduler; ck::BlockGemmPipelineVersion version; diff --git a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp index d087024913..8ee12a46ba 100644 --- a/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp +++ b/experimental/builder/test/conv/ck/test_ckb_conv_fwd_1d_i8.cpp @@ -33,7 +33,7 @@ TEST(FwdConvInstances, .with_gemm_config(GemmParams_Wmma_2x1_per_wave) .with_transfer(Transfer_4x32x1) .with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding) - .with_prefetch_config(1, PipelineScheduler::DEFAULT) + .with_prefetch_config(1, PipelineScheduler::INTRAWAVE) .with_num_conv_groups_to_merge(2) .with_gemm_pipeline(PipelineVersion::V1);