Adapt factories to warp GEMM and transfer parameters refactoring.

This commit is contained in:
Ville Pietilä
2026-01-09 09:17:45 -05:00
parent 3f0bac4e7b
commit f74e034ae9
16 changed files with 152 additions and 108 deletions

View File

@@ -172,8 +172,8 @@ concept SpecifiesTileThreadBlock = requires {
// Concept to check if a struct specifies warp GEMM info.
template <typename T>
concept SpecifiesWarpGemm = requires(T t) {
{ t.warp_gemm } -> WarpGemmDescriptor;
concept SpecifiesWarpGemm = requires {
{ T::warp_gemm } -> WarpGemmDescriptor;
};
// Concept to check if a struct specifies convolution input and output block transfer info.
@@ -212,8 +212,8 @@ concept SpecifiesLdsTransfer = requires(T t) {
// Concept to check if a struct specifies thread cluster access order info.
template <typename T>
concept SpecifiesThreadClusterAccessOrder = requires(T t) {
{ T::transfer.a.block_transfer_access_order } -> AccessOrderDescriptor;
{ T::transfer.b.block_transfer_access_order } -> AccessOrderDescriptor;
{ T::transfer.a.thread_distribution_access_order } -> AccessOrderDescriptor;
{ T::transfer.b.thread_distribution_access_order } -> AccessOrderDescriptor;
};
// Concept to check if a struct specifies source access order info.
@@ -341,15 +341,15 @@ concept SpecifiesMultipleDSupport = requires {
};
template <typename T>
concept SpecifiesXdl = requires {
{ T::warp_gemm.matrix_instruction } -> std::convertible_to<MatrixInstructionType>;
requires T::warp_gemm.matrix_instruction == MatrixInstructionType::XDL;
concept SpecifiesXdl = requires (T t){
{ t.warp_gemm.matrix_instruction } -> std::convertible_to<MatrixInstructionType>;
{ t.warp_gemm.matrix_instruction == MatrixInstructionType::XDL};
};
template <typename T>
concept SpecifiesWmma = requires {
{ T::warp_gemm.matrix_instruction } -> std::convertible_to<MatrixInstructionType>;
requires T::warp_gemm.matrix_instruction == MatrixInstructionType::WMMA;
concept SpecifiesWmma = requires (T t){
{ t.warp_gemm.matrix_instruction } -> std::convertible_to<MatrixInstructionType>;
{ t.warp_gemm.matrix_instruction == MatrixInstructionType::WMMA};
};
/******************************************** */

View File

@@ -35,7 +35,7 @@ struct ConvBwdWeightMultiDWmmaV3Factory
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto WARP_GEMM = ALGORITHM.warp_gemm;
static constexpr auto A_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
@@ -57,6 +57,11 @@ struct ConvBwdWeightMultiDWmmaV3Factory
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
"Invalid B source access order");
static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size ==
B_BLOCK_TRANSFER.global_memory_vector_load_size,
"A nd B block transfer vector load size need to be the same");
static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size;
// The forward convolution kernel class instance.
using Instance =
ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3<
@@ -78,11 +83,11 @@ struct ConvBwdWeightMultiDWmmaV3Factory
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.k1,
GRIDWISE_GEMM.m_per_wmma,
GRIDWISE_GEMM.n_per_wmma,
GRIDWISE_GEMM.m_wmma_per_wave,
GRIDWISE_GEMM.n_wmma_per_wave,
GMEM_VECTOR_LOAD_SIZE,
WARP_GEMM.gemm_m_per_instruction,
WARP_GEMM.gemm_n_per_instruction,
WARP_GEMM.gemm_m_iters_per_wave,
WARP_GEMM.gemm_n_iters_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,

View File

@@ -35,8 +35,7 @@ struct ConvBwdWeightMultiDXdlFactory
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm;
static constexpr auto A_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
@@ -53,6 +52,11 @@ struct ConvBwdWeightMultiDXdlFactory
static_assert(AccessOrderLimits4D<A_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits4D<B_BLOCK_TRANSFER.src_access_order>);
static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size ==
B_BLOCK_TRANSFER.global_memory_vector_load_size,
"A nd B block transfer vector load size need to be the same");
static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size;
// The forward convolution kernel class instance.
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle<
SPATIAL_DIM,
@@ -73,11 +77,11 @@ struct ConvBwdWeightMultiDXdlFactory
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.k1,
XDL_PARAMS.m_per_xdl,
XDL_PARAMS.n_per_xdl,
XDL_PARAMS.m_xdl_per_wave,
XDL_PARAMS.n_xdl_per_wave,
GMEM_VECTOR_LOAD_SIZE,
XDL_PARAMS.gemm_m_per_instruction,
XDL_PARAMS.gemm_n_per_instruction,
XDL_PARAMS.gemm_m_iters_per_wave,
XDL_PARAMS.gemm_n_iters_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,

View File

@@ -35,7 +35,7 @@ struct ConvBwdWeightTwoStageWmmaV3Factory
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto WARP_GEMM = ALGORITHM.warp_gemm;
static constexpr auto A_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
@@ -57,6 +57,11 @@ struct ConvBwdWeightTwoStageWmmaV3Factory
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
"Invalid B source access order");
static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size ==
B_BLOCK_TRANSFER.global_memory_vector_load_size,
"A nd B block transfer vector load size need to be the same");
static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size;
// The forward convolution kernel class instance.
using Instance =
ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3<
@@ -76,11 +81,11 @@ struct ConvBwdWeightTwoStageWmmaV3Factory
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.k1,
GRIDWISE_GEMM.m_per_wmma,
GRIDWISE_GEMM.n_per_wmma,
GRIDWISE_GEMM.m_wmma_per_wave,
GRIDWISE_GEMM.n_wmma_per_wave,
GMEM_VECTOR_LOAD_SIZE,
WARP_GEMM.gemm_m_per_instruction,
WARP_GEMM.gemm_n_per_instruction,
WARP_GEMM.gemm_m_iters_per_wave,
WARP_GEMM.gemm_n_iters_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,

View File

@@ -35,8 +35,7 @@ struct ConvBwdWeightTwoStageXdlFactory
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm;
static constexpr auto A_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
@@ -58,6 +57,11 @@ struct ConvBwdWeightTwoStageXdlFactory
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
"Invalid B source access order");
static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size ==
B_BLOCK_TRANSFER.global_memory_vector_load_size,
"A nd B block transfer vector load size need to be the same");
static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size;
// The forward convolution kernel class instance.
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<
SPATIAL_DIM,
@@ -76,11 +80,11 @@ struct ConvBwdWeightTwoStageXdlFactory
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.k1,
XDL_PARAMS.m_per_xdl,
XDL_PARAMS.n_per_xdl,
XDL_PARAMS.m_xdl_per_wave,
XDL_PARAMS.n_xdl_per_wave,
GMEM_VECTOR_LOAD_SIZE,
XDL_PARAMS.gemm_m_per_instruction,
XDL_PARAMS.gemm_n_per_instruction,
XDL_PARAMS.gemm_m_iters_per_wave,
XDL_PARAMS.gemm_n_iters_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,

View File

@@ -35,7 +35,7 @@ struct ConvBwdWeightWmmaFactory
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto WARP_GEMM = ALGORITHM.warp_gemm;
static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION =
internal::SetGridwiseGemmPipelineVersion<ALGORITHM>();
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
@@ -60,6 +60,11 @@ struct ConvBwdWeightWmmaFactory
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
"Invalid B source access order");
static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size ==
B_BLOCK_TRANSFER.global_memory_vector_load_size,
"A nd B block transfer vector load size need to be the same");
static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size;
// The forward convolution kernel class instance.
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle<
SPATIAL_DIM,
@@ -78,11 +83,11 @@ struct ConvBwdWeightWmmaFactory
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.k1,
GRIDWISE_GEMM.m_per_wmma,
GRIDWISE_GEMM.n_per_wmma,
GRIDWISE_GEMM.m_wmma_per_wave,
GRIDWISE_GEMM.n_wmma_per_wave,
GMEM_VECTOR_LOAD_SIZE,
WARP_GEMM.gemm_m_per_instruction,
WARP_GEMM.gemm_n_per_instruction,
WARP_GEMM.gemm_m_iters_per_wave,
WARP_GEMM.gemm_n_iters_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,

View File

@@ -35,7 +35,7 @@ struct ConvBwdWeightWmmaV3Factory
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto WARP_GEMM = ALGORITHM.warp_gemm;
static constexpr auto A_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
@@ -57,6 +57,11 @@ struct ConvBwdWeightWmmaV3Factory
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
"Invalid B source access order");
static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size ==
B_BLOCK_TRANSFER.global_memory_vector_load_size,
"A nd B block transfer vector load size need to be the same");
static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size;
// The forward convolution kernel class instance.
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffleV3<
SPATIAL_DIM,
@@ -75,11 +80,11 @@ struct ConvBwdWeightWmmaV3Factory
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.k1,
GRIDWISE_GEMM.m_per_wmma,
GRIDWISE_GEMM.n_per_wmma,
GRIDWISE_GEMM.m_wmma_per_wave,
GRIDWISE_GEMM.n_wmma_per_wave,
GMEM_VECTOR_LOAD_SIZE,
WARP_GEMM.gemm_m_per_instruction,
WARP_GEMM.gemm_n_per_instruction,
WARP_GEMM.gemm_m_iters_per_wave,
WARP_GEMM.gemm_n_iters_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,

View File

@@ -35,8 +35,7 @@ struct ConvBwdWeightXdlFactory
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm;
static constexpr auto A_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
@@ -53,6 +52,12 @@ struct ConvBwdWeightXdlFactory
static_assert(AccessOrderLimits4D<A_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits4D<B_BLOCK_TRANSFER.src_access_order>);
static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size ==
B_BLOCK_TRANSFER.global_memory_vector_load_size,
"A nd B block transfer vector load size need to be the same");
static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size;
// The forward convolution kernel class instance.
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
SPATIAL_DIM,
@@ -71,11 +76,11 @@ struct ConvBwdWeightXdlFactory
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.k1,
XDL_PARAMS.m_per_xdl,
XDL_PARAMS.n_per_xdl,
XDL_PARAMS.m_xdl_per_wave,
XDL_PARAMS.n_xdl_per_wave,
GMEM_VECTOR_LOAD_SIZE,
XDL_PARAMS.gemm_m_per_instruction,
XDL_PARAMS.gemm_n_per_instruction,
XDL_PARAMS.gemm_m_iters_per_wave,
XDL_PARAMS.gemm_n_iters_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,

View File

@@ -35,8 +35,7 @@ struct ConvBwdWeightXdlV3Factory
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm;
static constexpr auto A_BLOCK_TRANSFER =
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
@@ -58,6 +57,11 @@ struct ConvBwdWeightXdlV3Factory
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
"Invalid B source access order");
static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size ==
B_BLOCK_TRANSFER.global_memory_vector_load_size,
"A nd B block transfer vector load size need to be the same");
static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size;
// The forward convolution kernel class instance.
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3<
SPATIAL_DIM,
@@ -76,11 +80,11 @@ struct ConvBwdWeightXdlV3Factory
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.k1,
XDL_PARAMS.m_per_xdl,
XDL_PARAMS.n_per_xdl,
XDL_PARAMS.m_xdl_per_wave,
XDL_PARAMS.n_xdl_per_wave,
GMEM_VECTOR_LOAD_SIZE,
XDL_PARAMS.gemm_m_per_instruction,
XDL_PARAMS.gemm_n_per_instruction,
XDL_PARAMS.gemm_m_iters_per_wave,
XDL_PARAMS.gemm_n_iters_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,

View File

@@ -38,8 +38,7 @@ struct ConvFwdLargeTensorFactory
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm;
static constexpr auto A_BLOCK_TRANSFER =
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
@@ -79,12 +78,12 @@ struct ConvFwdLargeTensorFactory
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.ak1,
GRIDWISE_GEMM.bk1,
XDL_PARAMS.m_per_xdl,
XDL_PARAMS.n_per_xdl,
XDL_PARAMS.m_xdl_per_wave,
XDL_PARAMS.n_xdl_per_wave,
A_BLOCK_TRANSFER.global_memory_vector_load_size,
B_BLOCK_TRANSFER.global_memory_vector_load_size,
XDL_PARAMS.gemm_m_per_instruction,
XDL_PARAMS.gemm_n_per_instruction,
XDL_PARAMS.gemm_m_iters_per_wave,
XDL_PARAMS.gemm_n_iters_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,

View File

@@ -31,19 +31,18 @@ struct ConvFwdXdlV3Factory
using Ops = internal::ElementwiseOps<SIGNATURE>;
using AlgorithmType = decltype(ALGORITHM);
static_assert(ALGORITHM.transfer.a.lds_transfer.is_direct_load ==
ALGORITHM.transfer.b.lds_transfer.is_direct_load,
static_assert(ALGORITHM.transfer.a.lds_transfer_params.is_direct_load ==
ALGORITHM.transfer.b.lds_transfer_params.is_direct_load,
"A and B block transfers must both be direct load or not.");
static constexpr bool IS_DIRECT_LOAD = ALGORITHM.transfer.a.lds_transfer.is_direct_load;
static constexpr bool IS_DIRECT_LOAD = ALGORITHM.transfer.a.lds_transfer_params.is_direct_load;
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization<ALGORITHM>();
static constexpr internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION,
.gemm_spec = GEMM_SPECIALIZATION};
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm;
static constexpr auto A_BLOCK_TRANSFER =
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
@@ -83,12 +82,12 @@ struct ConvFwdXdlV3Factory
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.ak1,
GRIDWISE_GEMM.bk1,
XDL_PARAMS.m_per_xdl,
XDL_PARAMS.n_per_xdl,
XDL_PARAMS.m_xdl_per_wave,
XDL_PARAMS.n_xdl_per_wave,
A_BLOCK_TRANSFER.global_memory_vector_load_size,
B_BLOCK_TRANSFER.global_memory_vector_load_size,
XDL_PARAMS.gemm_m_per_instruction,
XDL_PARAMS.gemm_n_per_instruction,
XDL_PARAMS.gemm_m_iters_per_wave,
XDL_PARAMS.gemm_n_iters_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,

View File

@@ -38,7 +38,7 @@ struct ConvFwdWmmaFactory
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto WARP_GEMM = ALGORITHM.warp_gemm;
static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION =
internal::SetGridwiseGemmPipelineVersion<ALGORITHM>();
static constexpr auto A_BLOCK_TRANSFER =
@@ -57,6 +57,11 @@ struct ConvFwdWmmaFactory
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size ==
B_BLOCK_TRANSFER.global_memory_vector_load_size,
"A nd B block transfer vector load size need to be the same");
static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size;
// The forward convolution kernel class instance.
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<
SPATIAL_DIM,
@@ -80,11 +85,11 @@ struct ConvFwdWmmaFactory
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.k1,
GRIDWISE_GEMM.m_per_wmma,
GRIDWISE_GEMM.n_per_wmma,
GRIDWISE_GEMM.m_wmma_per_wave,
GRIDWISE_GEMM.n_wmma_per_wave,
GMEM_VECTOR_LOAD_SIZE,
WARP_GEMM.gemm_m_per_instruction,
WARP_GEMM.gemm_n_per_instruction,
WARP_GEMM.gemm_m_iters_per_wave,
WARP_GEMM.gemm_n_iters_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,

View File

@@ -38,8 +38,7 @@ struct ConvFwdXdlFactory
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm;
static constexpr auto A_BLOCK_TRANSFER =
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.a>();
static constexpr auto B_BLOCK_TRANSFER =
@@ -79,12 +78,12 @@ struct ConvFwdXdlFactory
BLOCK.per_block.m,
BLOCK.per_block.n,
BLOCK.per_block.k,
GRIDWISE_GEMM.ak1,
GRIDWISE_GEMM.bk1,
XDL_PARAMS.m_per_xdl,
XDL_PARAMS.n_per_xdl,
XDL_PARAMS.m_xdl_per_wave,
XDL_PARAMS.n_xdl_per_wave,
A_BLOCK_TRANSFER.global_memory_vector_load_size,
B_BLOCK_TRANSFER.global_memory_vector_load_size,
XDL_PARAMS.gemm_m_per_instruction,
XDL_PARAMS.gemm_n_per_instruction,
XDL_PARAMS.gemm_m_iters_per_wave,
XDL_PARAMS.gemm_n_iters_per_wave,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,

View File

@@ -15,6 +15,7 @@ struct BlockTransfer
ck::Array<size_t, 3> thread_cluster_dims{}; // k0, m, k1
ck::Array<size_t, 3> thread_cluster_order{};
ck::Array<size_t, 3> src_access_order{};
size_t global_memory_vector_load_size = 0;
size_t src_vector_dim = 0;
size_t src_scalar_per_vector = 0;
size_t lds_dst_scalar_per_vector = 0;
@@ -28,6 +29,7 @@ struct BwdBlockTransfer
ck::Array<size_t, ThreadSliceDim> thread_cluster_dims{};
ck::Array<size_t, ThreadSliceDim> thread_cluster_order{};
ck::Array<size_t, ThreadSliceDim> src_access_order{};
size_t global_memory_vector_load_size = 0;
size_t src_vector_dim = 0;
size_t src_scalar_per_vector = 0;
size_t lds_dst_scalar_per_vector = 0;
@@ -37,15 +39,16 @@ struct BwdBlockTransfer
template <auto TRANSFER>
constexpr BlockTransfer SetFwdConvBlockTransfer()
{
auto& block_xfer = TRANSFER.block_transfer;
auto& block_order = TRANSFER.block_transfer_access_order;
auto& block_xfer = TRANSFER.thread_distribution;
auto& block_order = TRANSFER.thread_distribution_access_order;
auto& src_order = TRANSFER.src_access_order;
auto& lds_cfg = TRANSFER.lds_transfer;
auto& lds_cfg = TRANSFER.lds_transfer_params;
return BlockTransfer{
.thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1},
.thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2]},
.src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]},
.global_memory_vector_load_size = lds_cfg.global_memory_vector_load_size,
.src_vector_dim = lds_cfg.src_vector_dim,
.src_scalar_per_vector = lds_cfg.src_scalar_per_vector,
.lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector,
@@ -57,10 +60,10 @@ constexpr BlockTransfer SetFwdConvBlockTransfer()
template <auto TRANSFER>
constexpr auto SetBwdConvBlockTransfer()
{
auto& block_xfer = TRANSFER.block_transfer;
auto& block_order = TRANSFER.block_transfer_access_order;
auto& block_xfer = TRANSFER.thread_distribution;
auto& block_order = TRANSFER.thread_distribution_access_order;
auto& src_order = TRANSFER.src_access_order;
auto& lds_cfg = TRANSFER.lds_transfer;
auto& lds_cfg = TRANSFER.lds_transfer_params;
constexpr auto array_length = block_order.order.size();
static_assert(block_order.order.size() == src_order.order.size(),
@@ -74,6 +77,7 @@ constexpr auto SetBwdConvBlockTransfer()
block_order.order[1],
block_order.order[2]},
.src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]},
.global_memory_vector_load_size = lds_cfg.global_memory_vector_load_size,
.src_vector_dim = lds_cfg.src_vector_dim,
.src_scalar_per_vector = lds_cfg.src_scalar_per_vector,
.lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector,
@@ -95,6 +99,7 @@ constexpr auto SetBwdConvBlockTransfer()
src_order.order[1],
src_order.order[2],
src_order.order[3]},
.global_memory_vector_load_size = lds_cfg.global_memory_vector_load_size,
.src_vector_dim = lds_cfg.src_vector_dim,
.src_scalar_per_vector = lds_cfg.src_scalar_per_vector,
.lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector,
@@ -119,17 +124,17 @@ struct CBlockTransfer
template <ConvSignatureDescriptor auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM>
constexpr CBlockTransfer SetCBlockTransfer()
{
auto& thread_cluster_dims = ALGORITHM.transfer.c.thread_cluster_dims;
auto& thread_cluster_dims = ALGORITHM.transfer.c.thread_distribution;
auto& epilogue_config = ALGORITHM.transfer.c.epilogue;
return CBlockTransfer{
.m_xdl_per_wave_per_shuffle = epilogue_config.m_xdl_per_wave_per_shuffle,
.n_xdl_per_wave_per_shuffle = epilogue_config.n_per_wave_per_shuffle,
.thread_cluster_dims =
{
thread_cluster_dims.m_block,
thread_cluster_dims.m_wave_per_xdl,
thread_cluster_dims.n_block,
thread_cluster_dims.n_wave_per_xdl,
thread_cluster_dims.gemm_m_block_size,
thread_cluster_dims.gemm_m_per_block,
thread_cluster_dims.gemm_n_block_size,
thread_cluster_dims.gemm_n_per_block,
},
.scalar_per_vector = epilogue_config.scalar_per_vector,
};

View File

@@ -38,7 +38,7 @@ struct BlockGemmSpec
template <ConvAlgorithmDescriptor auto ALGORITHM>
consteval BlockGemmSpec SetBlockGemm()
{
constexpr auto& BG = ALGORITHM.block_gemm_pipeline;
constexpr auto& BG = ALGORITHM.gemm_pipeline;
ck::BlockGemmPipelineScheduler scheduler;
ck::BlockGemmPipelineVersion version;

View File

@@ -33,7 +33,7 @@ TEST(FwdConvInstances,
.with_gemm_config(GemmParams_Wmma_2x1_per_wave)
.with_transfer(Transfer_4x32x1)
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
.with_prefetch_config(1, PipelineScheduler::DEFAULT)
.with_prefetch_config(1, PipelineScheduler::INTRAWAVE)
.with_num_conv_groups_to_merge(2)
.with_gemm_pipeline(PipelineVersion::V1);