mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 11:47:48 +00:00
Adapt factories to warp GEMM and transfer parameters refactoring.
This commit is contained in:
@@ -172,8 +172,8 @@ concept SpecifiesTileThreadBlock = requires {
|
||||
|
||||
// Concept to check if a struct specifies warp GEMM info.
|
||||
template <typename T>
|
||||
concept SpecifiesWarpGemm = requires(T t) {
|
||||
{ t.warp_gemm } -> WarpGemmDescriptor;
|
||||
concept SpecifiesWarpGemm = requires {
|
||||
{ T::warp_gemm } -> WarpGemmDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies convolution input and output block transfer info.
|
||||
@@ -212,8 +212,8 @@ concept SpecifiesLdsTransfer = requires(T t) {
|
||||
// Concept to check if a struct specifies thread cluster access order info.
|
||||
template <typename T>
|
||||
concept SpecifiesThreadClusterAccessOrder = requires(T t) {
|
||||
{ T::transfer.a.block_transfer_access_order } -> AccessOrderDescriptor;
|
||||
{ T::transfer.b.block_transfer_access_order } -> AccessOrderDescriptor;
|
||||
{ T::transfer.a.thread_distribution_access_order } -> AccessOrderDescriptor;
|
||||
{ T::transfer.b.thread_distribution_access_order } -> AccessOrderDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies source access order info.
|
||||
@@ -341,15 +341,15 @@ concept SpecifiesMultipleDSupport = requires {
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesXdl = requires {
|
||||
{ T::warp_gemm.matrix_instruction } -> std::convertible_to<MatrixInstructionType>;
|
||||
requires T::warp_gemm.matrix_instruction == MatrixInstructionType::XDL;
|
||||
concept SpecifiesXdl = requires (T t){
|
||||
{ t.warp_gemm.matrix_instruction } -> std::convertible_to<MatrixInstructionType>;
|
||||
{ t.warp_gemm.matrix_instruction == MatrixInstructionType::XDL};
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
concept SpecifiesWmma = requires {
|
||||
{ T::warp_gemm.matrix_instruction } -> std::convertible_to<MatrixInstructionType>;
|
||||
requires T::warp_gemm.matrix_instruction == MatrixInstructionType::WMMA;
|
||||
concept SpecifiesWmma = requires (T t){
|
||||
{ t.warp_gemm.matrix_instruction } -> std::convertible_to<MatrixInstructionType>;
|
||||
{ t.warp_gemm.matrix_instruction == MatrixInstructionType::WMMA};
|
||||
};
|
||||
|
||||
/******************************************** */
|
||||
|
||||
@@ -35,7 +35,7 @@ struct ConvBwdWeightMultiDWmmaV3Factory
|
||||
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto WARP_GEMM = ALGORITHM.warp_gemm;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
@@ -57,6 +57,11 @@ struct ConvBwdWeightMultiDWmmaV3Factory
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
|
||||
"Invalid B source access order");
|
||||
|
||||
static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size ==
|
||||
B_BLOCK_TRANSFER.global_memory_vector_load_size,
|
||||
"A nd B block transfer vector load size need to be the same");
|
||||
static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size;
|
||||
|
||||
// The backward weight convolution kernel class instance.
|
||||
using Instance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Wmma_CShuffleV3<
|
||||
@@ -78,11 +83,11 @@ struct ConvBwdWeightMultiDWmmaV3Factory
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
GRIDWISE_GEMM.m_per_wmma,
|
||||
GRIDWISE_GEMM.n_per_wmma,
|
||||
GRIDWISE_GEMM.m_wmma_per_wave,
|
||||
GRIDWISE_GEMM.n_wmma_per_wave,
|
||||
GMEM_VECTOR_LOAD_SIZE,
|
||||
WARP_GEMM.gemm_m_per_instruction,
|
||||
WARP_GEMM.gemm_n_per_instruction,
|
||||
WARP_GEMM.gemm_m_iters_per_wave,
|
||||
WARP_GEMM.gemm_n_iters_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
|
||||
@@ -35,8 +35,7 @@ struct ConvBwdWeightMultiDXdlFactory
|
||||
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
|
||||
static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
@@ -53,6 +52,11 @@ struct ConvBwdWeightMultiDXdlFactory
|
||||
static_assert(AccessOrderLimits4D<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits4D<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size ==
|
||||
B_BLOCK_TRANSFER.global_memory_vector_load_size,
|
||||
"A nd B block transfer vector load size need to be the same");
|
||||
static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size;
|
||||
|
||||
// The backward weight convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightMultipleD_Xdl_CShuffle<
|
||||
SPATIAL_DIM,
|
||||
@@ -73,11 +77,11 @@ struct ConvBwdWeightMultiDXdlFactory
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
XDL_PARAMS.m_per_xdl,
|
||||
XDL_PARAMS.n_per_xdl,
|
||||
XDL_PARAMS.m_xdl_per_wave,
|
||||
XDL_PARAMS.n_xdl_per_wave,
|
||||
GMEM_VECTOR_LOAD_SIZE,
|
||||
XDL_PARAMS.gemm_m_per_instruction,
|
||||
XDL_PARAMS.gemm_n_per_instruction,
|
||||
XDL_PARAMS.gemm_m_iters_per_wave,
|
||||
XDL_PARAMS.gemm_n_iters_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
|
||||
@@ -35,7 +35,7 @@ struct ConvBwdWeightTwoStageWmmaV3Factory
|
||||
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto WARP_GEMM = ALGORITHM.warp_gemm;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
@@ -57,6 +57,11 @@ struct ConvBwdWeightTwoStageWmmaV3Factory
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
|
||||
"Invalid B source access order");
|
||||
|
||||
static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size ==
|
||||
B_BLOCK_TRANSFER.global_memory_vector_load_size,
|
||||
"A nd B block transfer vector load size need to be the same");
|
||||
static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size;
|
||||
|
||||
// The backward weight convolution kernel class instance.
|
||||
using Instance =
|
||||
ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Wmma_CShuffleV3<
|
||||
@@ -76,11 +81,11 @@ struct ConvBwdWeightTwoStageWmmaV3Factory
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
GRIDWISE_GEMM.m_per_wmma,
|
||||
GRIDWISE_GEMM.n_per_wmma,
|
||||
GRIDWISE_GEMM.m_wmma_per_wave,
|
||||
GRIDWISE_GEMM.n_wmma_per_wave,
|
||||
GMEM_VECTOR_LOAD_SIZE,
|
||||
WARP_GEMM.gemm_m_per_instruction,
|
||||
WARP_GEMM.gemm_n_per_instruction,
|
||||
WARP_GEMM.gemm_m_iters_per_wave,
|
||||
WARP_GEMM.gemm_n_iters_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
|
||||
@@ -35,8 +35,7 @@ struct ConvBwdWeightTwoStageXdlFactory
|
||||
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
|
||||
static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
@@ -58,6 +57,11 @@ struct ConvBwdWeightTwoStageXdlFactory
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
|
||||
"Invalid B source access order");
|
||||
|
||||
static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size ==
|
||||
B_BLOCK_TRANSFER.global_memory_vector_load_size,
|
||||
"A nd B block transfer vector load size need to be the same");
|
||||
static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size;
|
||||
|
||||
// The backward weight convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeightTwoStage_Xdl_CShuffle<
|
||||
SPATIAL_DIM,
|
||||
@@ -76,11 +80,11 @@ struct ConvBwdWeightTwoStageXdlFactory
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
XDL_PARAMS.m_per_xdl,
|
||||
XDL_PARAMS.n_per_xdl,
|
||||
XDL_PARAMS.m_xdl_per_wave,
|
||||
XDL_PARAMS.n_xdl_per_wave,
|
||||
GMEM_VECTOR_LOAD_SIZE,
|
||||
XDL_PARAMS.gemm_m_per_instruction,
|
||||
XDL_PARAMS.gemm_n_per_instruction,
|
||||
XDL_PARAMS.gemm_m_iters_per_wave,
|
||||
XDL_PARAMS.gemm_n_iters_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
|
||||
@@ -35,7 +35,7 @@ struct ConvBwdWeightWmmaFactory
|
||||
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto WARP_GEMM = ALGORITHM.warp_gemm;
|
||||
static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION =
|
||||
internal::SetGridwiseGemmPipelineVersion<ALGORITHM>();
|
||||
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
|
||||
@@ -60,6 +60,11 @@ struct ConvBwdWeightWmmaFactory
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
|
||||
"Invalid B source access order");
|
||||
|
||||
static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size ==
|
||||
B_BLOCK_TRANSFER.global_memory_vector_load_size,
|
||||
"A nd B block transfer vector load size need to be the same");
|
||||
static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size;
|
||||
|
||||
// The backward weight convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffle<
|
||||
SPATIAL_DIM,
|
||||
@@ -78,11 +83,11 @@ struct ConvBwdWeightWmmaFactory
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
GRIDWISE_GEMM.m_per_wmma,
|
||||
GRIDWISE_GEMM.n_per_wmma,
|
||||
GRIDWISE_GEMM.m_wmma_per_wave,
|
||||
GRIDWISE_GEMM.n_wmma_per_wave,
|
||||
GMEM_VECTOR_LOAD_SIZE,
|
||||
WARP_GEMM.gemm_m_per_instruction,
|
||||
WARP_GEMM.gemm_n_per_instruction,
|
||||
WARP_GEMM.gemm_m_iters_per_wave,
|
||||
WARP_GEMM.gemm_n_iters_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
|
||||
@@ -35,7 +35,7 @@ struct ConvBwdWeightWmmaV3Factory
|
||||
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto WARP_GEMM = ALGORITHM.warp_gemm;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
@@ -57,6 +57,11 @@ struct ConvBwdWeightWmmaV3Factory
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
|
||||
"Invalid B source access order");
|
||||
|
||||
static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size ==
|
||||
B_BLOCK_TRANSFER.global_memory_vector_load_size,
|
||||
"A nd B block transfer vector load size need to be the same");
|
||||
static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size;
|
||||
|
||||
// The backward weight convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Wmma_CShuffleV3<
|
||||
SPATIAL_DIM,
|
||||
@@ -75,11 +80,11 @@ struct ConvBwdWeightWmmaV3Factory
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
GRIDWISE_GEMM.m_per_wmma,
|
||||
GRIDWISE_GEMM.n_per_wmma,
|
||||
GRIDWISE_GEMM.m_wmma_per_wave,
|
||||
GRIDWISE_GEMM.n_wmma_per_wave,
|
||||
GMEM_VECTOR_LOAD_SIZE,
|
||||
WARP_GEMM.gemm_m_per_instruction,
|
||||
WARP_GEMM.gemm_n_per_instruction,
|
||||
WARP_GEMM.gemm_m_iters_per_wave,
|
||||
WARP_GEMM.gemm_n_iters_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
|
||||
@@ -35,8 +35,7 @@ struct ConvBwdWeightXdlFactory
|
||||
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
|
||||
static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
@@ -53,6 +52,12 @@ struct ConvBwdWeightXdlFactory
|
||||
static_assert(AccessOrderLimits4D<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits4D<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size ==
|
||||
B_BLOCK_TRANSFER.global_memory_vector_load_size,
|
||||
"A nd B block transfer vector load size need to be the same");
|
||||
static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size;
|
||||
|
||||
|
||||
// The backward weight convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffle<
|
||||
SPATIAL_DIM,
|
||||
@@ -71,11 +76,11 @@ struct ConvBwdWeightXdlFactory
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
XDL_PARAMS.m_per_xdl,
|
||||
XDL_PARAMS.n_per_xdl,
|
||||
XDL_PARAMS.m_xdl_per_wave,
|
||||
XDL_PARAMS.n_xdl_per_wave,
|
||||
GMEM_VECTOR_LOAD_SIZE,
|
||||
XDL_PARAMS.gemm_m_per_instruction,
|
||||
XDL_PARAMS.gemm_n_per_instruction,
|
||||
XDL_PARAMS.gemm_m_iters_per_wave,
|
||||
XDL_PARAMS.gemm_n_iters_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
|
||||
@@ -35,8 +35,7 @@ struct ConvBwdWeightXdlV3Factory
|
||||
internal::SetBwdWeightConvSpecialization<ALGORITHM>();
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
|
||||
static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetBwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
@@ -58,6 +57,11 @@ struct ConvBwdWeightXdlV3Factory
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>,
|
||||
"Invalid B source access order");
|
||||
|
||||
static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size ==
|
||||
B_BLOCK_TRANSFER.global_memory_vector_load_size,
|
||||
"A nd B block transfer vector load size need to be the same");
|
||||
static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size;
|
||||
|
||||
// The backward weight convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvBwdWeight_Xdl_CShuffleV3<
|
||||
SPATIAL_DIM,
|
||||
@@ -76,11 +80,11 @@ struct ConvBwdWeightXdlV3Factory
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
XDL_PARAMS.m_per_xdl,
|
||||
XDL_PARAMS.n_per_xdl,
|
||||
XDL_PARAMS.m_xdl_per_wave,
|
||||
XDL_PARAMS.n_xdl_per_wave,
|
||||
GMEM_VECTOR_LOAD_SIZE,
|
||||
XDL_PARAMS.gemm_m_per_instruction,
|
||||
XDL_PARAMS.gemm_n_per_instruction,
|
||||
XDL_PARAMS.gemm_m_iters_per_wave,
|
||||
XDL_PARAMS.gemm_n_iters_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
|
||||
@@ -38,8 +38,7 @@ struct ConvFwdLargeTensorFactory
|
||||
|
||||
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
|
||||
static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
@@ -79,12 +78,12 @@ struct ConvFwdLargeTensorFactory
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.ak1,
|
||||
GRIDWISE_GEMM.bk1,
|
||||
XDL_PARAMS.m_per_xdl,
|
||||
XDL_PARAMS.n_per_xdl,
|
||||
XDL_PARAMS.m_xdl_per_wave,
|
||||
XDL_PARAMS.n_xdl_per_wave,
|
||||
A_BLOCK_TRANSFER.global_memory_vector_load_size,
|
||||
B_BLOCK_TRANSFER.global_memory_vector_load_size,
|
||||
XDL_PARAMS.gemm_m_per_instruction,
|
||||
XDL_PARAMS.gemm_n_per_instruction,
|
||||
XDL_PARAMS.gemm_m_iters_per_wave,
|
||||
XDL_PARAMS.gemm_n_iters_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
|
||||
@@ -31,19 +31,18 @@ struct ConvFwdXdlV3Factory
|
||||
using Ops = internal::ElementwiseOps<SIGNATURE>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static_assert(ALGORITHM.transfer.a.lds_transfer.is_direct_load ==
|
||||
ALGORITHM.transfer.b.lds_transfer.is_direct_load,
|
||||
static_assert(ALGORITHM.transfer.a.lds_transfer_params.is_direct_load ==
|
||||
ALGORITHM.transfer.b.lds_transfer_params.is_direct_load,
|
||||
"A and B block transfers must both be direct load or not.");
|
||||
|
||||
static constexpr bool IS_DIRECT_LOAD = ALGORITHM.transfer.a.lds_transfer.is_direct_load;
|
||||
static constexpr bool IS_DIRECT_LOAD = ALGORITHM.transfer.a.lds_transfer_params.is_direct_load;
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION = internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION = internal::SetGemmSpecialization<ALGORITHM>();
|
||||
static constexpr internal::ConvSpec SPECIALIZATION{.conv_spec = FWD_CONV_SPECIALIZATION,
|
||||
.gemm_spec = GEMM_SPECIALIZATION};
|
||||
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
|
||||
static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
@@ -83,12 +82,12 @@ struct ConvFwdXdlV3Factory
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.ak1,
|
||||
GRIDWISE_GEMM.bk1,
|
||||
XDL_PARAMS.m_per_xdl,
|
||||
XDL_PARAMS.n_per_xdl,
|
||||
XDL_PARAMS.m_xdl_per_wave,
|
||||
XDL_PARAMS.n_xdl_per_wave,
|
||||
A_BLOCK_TRANSFER.global_memory_vector_load_size,
|
||||
B_BLOCK_TRANSFER.global_memory_vector_load_size,
|
||||
XDL_PARAMS.gemm_m_per_instruction,
|
||||
XDL_PARAMS.gemm_n_per_instruction,
|
||||
XDL_PARAMS.gemm_m_iters_per_wave,
|
||||
XDL_PARAMS.gemm_n_iters_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
|
||||
@@ -38,7 +38,7 @@ struct ConvFwdWmmaFactory
|
||||
|
||||
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto WARP_GEMM = ALGORITHM.warp_gemm;
|
||||
static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION =
|
||||
internal::SetGridwiseGemmPipelineVersion<ALGORITHM>();
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
@@ -57,6 +57,11 @@ struct ConvFwdWmmaFactory
|
||||
static_assert(AccessOrderLimits3D<A_BLOCK_TRANSFER.src_access_order>);
|
||||
static_assert(AccessOrderLimits3D<B_BLOCK_TRANSFER.src_access_order>);
|
||||
|
||||
static_assert(A_BLOCK_TRANSFER.global_memory_vector_load_size ==
|
||||
B_BLOCK_TRANSFER.global_memory_vector_load_size,
|
||||
"A nd B block transfer vector load size need to be the same");
|
||||
static constexpr size_t GMEM_VECTOR_LOAD_SIZE = A_BLOCK_TRANSFER.global_memory_vector_load_size;
|
||||
|
||||
// The forward convolution kernel class instance.
|
||||
using Instance = ck::tensor_operation::device::DeviceGroupedConvFwdMultipleD_Wmma_CShuffle<
|
||||
SPATIAL_DIM,
|
||||
@@ -80,11 +85,11 @@ struct ConvFwdWmmaFactory
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.k1,
|
||||
GRIDWISE_GEMM.m_per_wmma,
|
||||
GRIDWISE_GEMM.n_per_wmma,
|
||||
GRIDWISE_GEMM.m_wmma_per_wave,
|
||||
GRIDWISE_GEMM.n_wmma_per_wave,
|
||||
GMEM_VECTOR_LOAD_SIZE,
|
||||
WARP_GEMM.gemm_m_per_instruction,
|
||||
WARP_GEMM.gemm_n_per_instruction,
|
||||
WARP_GEMM.gemm_m_iters_per_wave,
|
||||
WARP_GEMM.gemm_n_iters_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
|
||||
@@ -38,8 +38,7 @@ struct ConvFwdXdlFactory
|
||||
|
||||
static constexpr auto LOOP_SCHEDULER = internal::SetLoopScheduler<ALGORITHM>();
|
||||
static constexpr auto BLOCK = internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto XDL_PARAMS = GRIDWISE_GEMM.xdl_params;
|
||||
static constexpr auto XDL_PARAMS = ALGORITHM.warp_gemm;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
@@ -79,12 +78,12 @@ struct ConvFwdXdlFactory
|
||||
BLOCK.per_block.m,
|
||||
BLOCK.per_block.n,
|
||||
BLOCK.per_block.k,
|
||||
GRIDWISE_GEMM.ak1,
|
||||
GRIDWISE_GEMM.bk1,
|
||||
XDL_PARAMS.m_per_xdl,
|
||||
XDL_PARAMS.n_per_xdl,
|
||||
XDL_PARAMS.m_xdl_per_wave,
|
||||
XDL_PARAMS.n_xdl_per_wave,
|
||||
A_BLOCK_TRANSFER.global_memory_vector_load_size,
|
||||
B_BLOCK_TRANSFER.global_memory_vector_load_size,
|
||||
XDL_PARAMS.gemm_m_per_instruction,
|
||||
XDL_PARAMS.gemm_n_per_instruction,
|
||||
XDL_PARAMS.gemm_m_iters_per_wave,
|
||||
XDL_PARAMS.gemm_n_iters_per_wave,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_dims>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.thread_cluster_order>,
|
||||
to_sequence_v<A_BLOCK_TRANSFER.src_access_order>,
|
||||
|
||||
@@ -15,6 +15,7 @@ struct BlockTransfer
|
||||
ck::Array<size_t, 3> thread_cluster_dims{}; // k0, m, k1
|
||||
ck::Array<size_t, 3> thread_cluster_order{};
|
||||
ck::Array<size_t, 3> src_access_order{};
|
||||
size_t global_memory_vector_load_size = 0;
|
||||
size_t src_vector_dim = 0;
|
||||
size_t src_scalar_per_vector = 0;
|
||||
size_t lds_dst_scalar_per_vector = 0;
|
||||
@@ -28,6 +29,7 @@ struct BwdBlockTransfer
|
||||
ck::Array<size_t, ThreadSliceDim> thread_cluster_dims{};
|
||||
ck::Array<size_t, ThreadSliceDim> thread_cluster_order{};
|
||||
ck::Array<size_t, ThreadSliceDim> src_access_order{};
|
||||
size_t global_memory_vector_load_size = 0;
|
||||
size_t src_vector_dim = 0;
|
||||
size_t src_scalar_per_vector = 0;
|
||||
size_t lds_dst_scalar_per_vector = 0;
|
||||
@@ -37,15 +39,16 @@ struct BwdBlockTransfer
|
||||
template <auto TRANSFER>
|
||||
constexpr BlockTransfer SetFwdConvBlockTransfer()
|
||||
{
|
||||
auto& block_xfer = TRANSFER.block_transfer;
|
||||
auto& block_order = TRANSFER.block_transfer_access_order;
|
||||
auto& block_xfer = TRANSFER.thread_distribution;
|
||||
auto& block_order = TRANSFER.thread_distribution_access_order;
|
||||
auto& src_order = TRANSFER.src_access_order;
|
||||
auto& lds_cfg = TRANSFER.lds_transfer;
|
||||
auto& lds_cfg = TRANSFER.lds_transfer_params;
|
||||
|
||||
return BlockTransfer{
|
||||
.thread_cluster_dims = {block_xfer.k0, block_xfer.m_n, block_xfer.k1},
|
||||
.thread_cluster_order = {block_order.order[0], block_order.order[1], block_order.order[2]},
|
||||
.src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]},
|
||||
.global_memory_vector_load_size = lds_cfg.global_memory_vector_load_size,
|
||||
.src_vector_dim = lds_cfg.src_vector_dim,
|
||||
.src_scalar_per_vector = lds_cfg.src_scalar_per_vector,
|
||||
.lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector,
|
||||
@@ -57,10 +60,10 @@ constexpr BlockTransfer SetFwdConvBlockTransfer()
|
||||
template <auto TRANSFER>
|
||||
constexpr auto SetBwdConvBlockTransfer()
|
||||
{
|
||||
auto& block_xfer = TRANSFER.block_transfer;
|
||||
auto& block_order = TRANSFER.block_transfer_access_order;
|
||||
auto& block_xfer = TRANSFER.thread_distribution;
|
||||
auto& block_order = TRANSFER.thread_distribution_access_order;
|
||||
auto& src_order = TRANSFER.src_access_order;
|
||||
auto& lds_cfg = TRANSFER.lds_transfer;
|
||||
auto& lds_cfg = TRANSFER.lds_transfer_params;
|
||||
|
||||
constexpr auto array_length = block_order.order.size();
|
||||
static_assert(block_order.order.size() == src_order.order.size(),
|
||||
@@ -74,6 +77,7 @@ constexpr auto SetBwdConvBlockTransfer()
|
||||
block_order.order[1],
|
||||
block_order.order[2]},
|
||||
.src_access_order = {src_order.order[0], src_order.order[1], src_order.order[2]},
|
||||
.global_memory_vector_load_size = lds_cfg.global_memory_vector_load_size,
|
||||
.src_vector_dim = lds_cfg.src_vector_dim,
|
||||
.src_scalar_per_vector = lds_cfg.src_scalar_per_vector,
|
||||
.lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector,
|
||||
@@ -95,6 +99,7 @@ constexpr auto SetBwdConvBlockTransfer()
|
||||
src_order.order[1],
|
||||
src_order.order[2],
|
||||
src_order.order[3]},
|
||||
.global_memory_vector_load_size = lds_cfg.global_memory_vector_load_size,
|
||||
.src_vector_dim = lds_cfg.src_vector_dim,
|
||||
.src_scalar_per_vector = lds_cfg.src_scalar_per_vector,
|
||||
.lds_dst_scalar_per_vector = lds_cfg.lds_dst_scalar_per_vector,
|
||||
@@ -119,17 +124,17 @@ struct CBlockTransfer
|
||||
template <ConvSignatureDescriptor auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr CBlockTransfer SetCBlockTransfer()
|
||||
{
|
||||
auto& thread_cluster_dims = ALGORITHM.transfer.c.thread_cluster_dims;
|
||||
auto& thread_cluster_dims = ALGORITHM.transfer.c.thread_distribution;
|
||||
auto& epilogue_config = ALGORITHM.transfer.c.epilogue;
|
||||
return CBlockTransfer{
|
||||
.m_xdl_per_wave_per_shuffle = epilogue_config.m_xdl_per_wave_per_shuffle,
|
||||
.n_xdl_per_wave_per_shuffle = epilogue_config.n_per_wave_per_shuffle,
|
||||
.thread_cluster_dims =
|
||||
{
|
||||
thread_cluster_dims.m_block,
|
||||
thread_cluster_dims.m_wave_per_xdl,
|
||||
thread_cluster_dims.n_block,
|
||||
thread_cluster_dims.n_wave_per_xdl,
|
||||
thread_cluster_dims.gemm_m_block_size,
|
||||
thread_cluster_dims.gemm_m_per_block,
|
||||
thread_cluster_dims.gemm_n_block_size,
|
||||
thread_cluster_dims.gemm_n_per_block,
|
||||
},
|
||||
.scalar_per_vector = epilogue_config.scalar_per_vector,
|
||||
};
|
||||
|
||||
@@ -38,7 +38,7 @@ struct BlockGemmSpec
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
consteval BlockGemmSpec SetBlockGemm()
|
||||
{
|
||||
constexpr auto& BG = ALGORITHM.block_gemm_pipeline;
|
||||
constexpr auto& BG = ALGORITHM.gemm_pipeline;
|
||||
|
||||
ck::BlockGemmPipelineScheduler scheduler;
|
||||
ck::BlockGemmPipelineVersion version;
|
||||
|
||||
@@ -33,7 +33,7 @@ TEST(FwdConvInstances,
|
||||
.with_gemm_config(GemmParams_Wmma_2x1_per_wave)
|
||||
.with_transfer(Transfer_4x32x1)
|
||||
.with_fwd_specializations(ConvSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
|
||||
.with_prefetch_config(1, PipelineScheduler::DEFAULT)
|
||||
.with_prefetch_config(1, PipelineScheduler::INTRAWAVE)
|
||||
.with_num_conv_groups_to_merge(2)
|
||||
.with_gemm_pipeline(PipelineVersion::V1);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user