mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-04 05:31:24 +00:00
ck-builder: group transfer operations per tensor (#3217)
Grouping transfer operations per tensor makes it easier to constrain on and operate with the transfer operations. As an example, we can now deduplicate the logic for translating the transfer operations from the ck-builder interface to the old ck interface for the A and B tensors.
This commit is contained in:
@@ -125,31 +125,31 @@ concept SpecifiesGridwiseWmmaGemm = requires {
|
||||
// Concept to check if a struct specifies convolution input and output block transfer info.
|
||||
template <typename T>
|
||||
concept SpecifiesBlockTransfer = requires(T t) {
|
||||
{ T::block_transfer.block_transfer_a } -> BlockTransferDescriptor;
|
||||
{ T::block_transfer.block_transfer_b } -> BlockTransferDescriptor;
|
||||
{ T::block_transfer.thread_cluster_dims_c } -> ThreadClusterDescriptor;
|
||||
{ T::transfer.a.block_transfer } -> BlockTransferDescriptor;
|
||||
{ T::transfer.b.block_transfer } -> BlockTransferDescriptor;
|
||||
{ T::transfer.c.thread_cluster_dims } -> ThreadClusterDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies LDS transfer info for tensors A, B, and C.
|
||||
template <typename T>
|
||||
concept SpecifiesLdsTransfer = requires(T t) {
|
||||
{ T::block_transfer.lds_transfer_a } -> LdsTransferDescriptor;
|
||||
{ T::block_transfer.lds_transfer_b } -> LdsTransferDescriptor;
|
||||
{ T::block_transfer.epilogue_c } -> EpilogueDescriptor;
|
||||
{ T::transfer.a.lds_transfer } -> LdsTransferDescriptor;
|
||||
{ T::transfer.b.lds_transfer } -> LdsTransferDescriptor;
|
||||
{ T::transfer.c.epilogue } -> EpilogueDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies thread cluster access order info.
|
||||
template <typename T>
|
||||
concept SpecifiesThreadClusterAccessOrder = requires(T t) {
|
||||
{ T::block_transfer.block_transfer_access_order_a } -> AccessOrderDescriptor;
|
||||
{ T::block_transfer.block_transfer_access_order_b } -> AccessOrderDescriptor;
|
||||
{ T::transfer.a.block_transfer_access_order } -> AccessOrderDescriptor;
|
||||
{ T::transfer.b.block_transfer_access_order } -> AccessOrderDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if a struct specifies source access order info.
|
||||
template <typename T>
|
||||
concept SpecifiesSourceAccessOrder = requires(T t) {
|
||||
{ T::block_transfer.src_access_order_a } -> AccessOrderDescriptor;
|
||||
{ T::block_transfer.src_access_order_b } -> AccessOrderDescriptor;
|
||||
{ T::transfer.a.src_access_order } -> AccessOrderDescriptor;
|
||||
{ T::transfer.b.src_access_order } -> AccessOrderDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if struct specifies block GEMM.
|
||||
@@ -246,14 +246,14 @@ concept SpecifiesDlThreadCluster = requires {
|
||||
// Concept to check if algorithm specifies DL block transfer
|
||||
template <typename T>
|
||||
concept SpecifiesDlBlockTransfer = requires {
|
||||
{ T::block_transfer_a } -> DlBlockTransferDescriptor;
|
||||
{ T::block_transfer_b } -> DlBlockTransferDescriptor;
|
||||
{ T::transfer.a.block_transfer } -> DlBlockTransferDescriptor;
|
||||
{ T::transfer.b.block_transfer } -> DlBlockTransferDescriptor;
|
||||
};
|
||||
|
||||
// Concept to check if algorithm specifies DL C thread transfer
|
||||
template <typename T>
|
||||
concept SpecifiesDlEpilogue = requires {
|
||||
{ T::epilogue_c } -> DlEpilogueDescriptor;
|
||||
{ T::transfer.c.epilogue } -> DlEpilogueDescriptor;
|
||||
};
|
||||
|
||||
/******************************************** */
|
||||
|
||||
@@ -25,8 +25,7 @@
|
||||
// `constexpr` Helper Functions:
|
||||
// - SetThreadBlockInfo: Determines thread block dimensions and tile sizes.
|
||||
// - SetConvTuningInfo: Sets XDL and AK1/BK1 tuning parameters.
|
||||
// - SetFwdConvABlockTransfer: Configures A tensor block transfer parameters.
|
||||
// - SetFwdConvBBlockTransfer: Configures B tensor block transfer parameters.
|
||||
// - SetFwdConvBlockTransfer: Configures A/B tensor block transfer parameters.
|
||||
// - SetCBlockTransfer: Configures C tensor block transfer parameters.
|
||||
// - SetBlockGemmPipelineVersion: Maps pipeline version enum to CK types.
|
||||
//
|
||||
@@ -381,32 +380,13 @@ struct BlockTransfer
|
||||
bool lds_padding = false;
|
||||
};
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr BlockTransfer SetFwdConvABlockTransfer()
|
||||
template <auto TRANSFER>
|
||||
constexpr BlockTransfer SetFwdConvBlockTransfer()
|
||||
{
|
||||
constexpr auto& TCL = ALGORITHM.block_transfer.block_transfer_a;
|
||||
constexpr auto& TCO = ALGORITHM.block_transfer.block_transfer_access_order_a;
|
||||
constexpr auto& SAO = ALGORITHM.block_transfer.src_access_order_a;
|
||||
constexpr auto& LDS = ALGORITHM.block_transfer.lds_transfer_a;
|
||||
|
||||
BlockTransfer block_transfer{.thread_cluster_dims = {TCL.k0, TCL.m_n, TCL.k1},
|
||||
.thread_cluster_order = {TCO.order[0], TCO.order[1], TCO.order[2]},
|
||||
.src_access_order = {SAO.order[0], SAO.order[1], SAO.order[2]},
|
||||
.src_vector_dim = LDS.src_vector_dim,
|
||||
.src_scalar_per_vector = LDS.src_scalar_per_vector,
|
||||
.lds_dst_scalar_per_vector = LDS.lds_dst_scalar_per_vector,
|
||||
.is_direct_load = LDS.is_direct_load,
|
||||
.lds_padding = LDS.lds_padding};
|
||||
return block_transfer;
|
||||
}
|
||||
|
||||
template <ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr BlockTransfer SetFwdConvBBlockTransfer()
|
||||
{
|
||||
constexpr auto& TCL = ALGORITHM.block_transfer.block_transfer_b;
|
||||
constexpr auto& TCO = ALGORITHM.block_transfer.block_transfer_access_order_b;
|
||||
constexpr auto& SAO = ALGORITHM.block_transfer.src_access_order_b;
|
||||
constexpr auto& LDS = ALGORITHM.block_transfer.lds_transfer_b;
|
||||
constexpr auto& TCL = TRANSFER.block_transfer;
|
||||
constexpr auto& TCO = TRANSFER.block_transfer_access_order;
|
||||
constexpr auto& SAO = TRANSFER.src_access_order;
|
||||
constexpr auto& LDS = TRANSFER.lds_transfer;
|
||||
|
||||
BlockTransfer block_transfer{.thread_cluster_dims = {TCL.k0, TCL.m_n, TCL.k1},
|
||||
.thread_cluster_order = {TCO.order[0], TCO.order[1], TCO.order[2]},
|
||||
@@ -431,8 +411,8 @@ struct CBlockTransfer
|
||||
template <ConvSignatureDescriptor auto SIGNATURE, ConvAlgorithmDescriptor auto ALGORITHM>
|
||||
constexpr CBlockTransfer SetCBlockTransfer()
|
||||
{
|
||||
constexpr auto& TCL = ALGORITHM.block_transfer.thread_cluster_dims_c;
|
||||
constexpr auto& EPC = ALGORITHM.block_transfer.epilogue_c;
|
||||
constexpr auto& TCL = ALGORITHM.transfer.c.thread_cluster_dims;
|
||||
constexpr auto& EPC = ALGORITHM.transfer.c.epilogue;
|
||||
CBlockTransfer block_transfer{.m_per_wave_per_shuffle = EPC.m_per_wave_per_shuffle,
|
||||
.n_per_wave_per_shuffle = EPC.n_per_wave_per_shuffle,
|
||||
.thread_cluster_dims =
|
||||
@@ -568,11 +548,11 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
using Ops = factory_internal::ElementwiseOps<get_elementwise_operation<SIGNATURE>()>;
|
||||
using AlgorithmType = decltype(ALGORITHM);
|
||||
|
||||
static_assert(ALGORITHM.block_transfer.lds_transfer_a.is_direct_load ==
|
||||
ALGORITHM.block_transfer.lds_transfer_b.is_direct_load,
|
||||
static_assert(ALGORITHM.transfer.a.lds_transfer.is_direct_load ==
|
||||
ALGORITHM.transfer.b.lds_transfer.is_direct_load,
|
||||
"A and B block transfers must both be direct load or not.");
|
||||
|
||||
static constexpr bool IS_DIRECT_LOAD = ALGORITHM.block_transfer.lds_transfer_a.is_direct_load;
|
||||
static constexpr bool IS_DIRECT_LOAD = ALGORITHM.transfer.a.lds_transfer.is_direct_load;
|
||||
static constexpr auto FWD_CONV_SPECIALIZATION =
|
||||
factory_internal::SetFwdConvSpecialization<ALGORITHM>();
|
||||
static constexpr auto GEMM_SPECIALIZATION =
|
||||
@@ -583,9 +563,9 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
factory_internal::SetFwdConvABlockTransfer<ALGORITHM>();
|
||||
factory_internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
factory_internal::SetFwdConvBBlockTransfer<ALGORITHM>();
|
||||
factory_internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER =
|
||||
factory_internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
static constexpr auto BLOCK_GEMM = factory_internal::SetBlockGemm<ALGORITHM>();
|
||||
@@ -681,9 +661,9 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
factory_internal::SetFwdConvABlockTransfer<ALGORITHM>();
|
||||
factory_internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
factory_internal::SetFwdConvBBlockTransfer<ALGORITHM>();
|
||||
factory_internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER =
|
||||
factory_internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
|
||||
@@ -780,9 +760,9 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
static constexpr auto GRIDWISE_GEMM_PIPELINE_VERSION =
|
||||
factory_internal::SetGridwiseGemmPipelineVersion<ALGORITHM>();
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
factory_internal::SetFwdConvABlockTransfer<ALGORITHM>();
|
||||
factory_internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
factory_internal::SetFwdConvBBlockTransfer<ALGORITHM>();
|
||||
factory_internal::SetFwdConvBlockTransfer<ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER =
|
||||
factory_internal::SetCBlockTransfer<SIGNATURE, ALGORITHM>();
|
||||
|
||||
@@ -884,7 +864,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
using M1N1ThreadClusterN1Xs = to_sequence_v<DL_CLUSTER.n1_xs>;
|
||||
|
||||
// A Block Transfer from descriptor - K0_M0_M1_K1 tensor format
|
||||
static constexpr auto DL_A_TRANSFER = ALGORITHM.block_transfer_a;
|
||||
static constexpr auto DL_A_TRANSFER = ALGORITHM.transfer.a.block_transfer;
|
||||
using ABlockTransferThreadSliceLengths_K0_M0_M1_K1 =
|
||||
to_sequence_v<DL_A_TRANSFER.thread_slice_lengths>;
|
||||
using ABlockTransferThreadClusterLengths_K0_M0_M1_K1 =
|
||||
@@ -900,7 +880,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
to_sequence_v<DL_A_TRANSFER.dst_vector_tensor_lengths>;
|
||||
|
||||
// B Block Transfer from descriptor - K0_N0_N1_K1 tensor format
|
||||
static constexpr auto DL_B_TRANSFER = ALGORITHM.block_transfer_b;
|
||||
static constexpr auto DL_B_TRANSFER = ALGORITHM.transfer.b.block_transfer;
|
||||
using BBlockTransferThreadSliceLengths_K0_N0_N1_K1 =
|
||||
to_sequence_v<DL_B_TRANSFER.thread_slice_lengths>;
|
||||
using BBlockTransferThreadClusterLengths_K0_N0_N1_K1 =
|
||||
@@ -916,7 +896,7 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
to_sequence_v<DL_B_TRANSFER.dst_vector_tensor_lengths>;
|
||||
|
||||
// C Thread Transfer from descriptor
|
||||
static constexpr auto DL_C_TRANSFER = ALGORITHM.epilogue_c;
|
||||
static constexpr auto DL_C_TRANSFER = ALGORITHM.transfer.c.epilogue;
|
||||
using CThreadTransferSrcDstAccessOrder = to_sequence_v<DL_C_TRANSFER.src_dst_access_order>;
|
||||
static constexpr ck::index_t CThreadTransferSrcDstVectorDim = DL_C_TRANSFER.src_dst_vector_dim;
|
||||
static constexpr ck::index_t CThreadTransferDstScalarPerVector =
|
||||
@@ -998,9 +978,9 @@ struct ConvFactory<SIGNATURE, ALGORITHM, VERSION>
|
||||
static constexpr auto BLOCK = factory_internal::SetThreadBlockInfo<BASE_ALGORITHM>();
|
||||
static constexpr auto GRIDWISE_GEMM = BASE_ALGORITHM.gridwise_gemm;
|
||||
static constexpr auto A_BLOCK_TRANSFER =
|
||||
factory_internal::SetFwdConvABlockTransfer<BASE_ALGORITHM>();
|
||||
factory_internal::SetFwdConvBlockTransfer<BASE_ALGORITHM.transfer.a>();
|
||||
static constexpr auto B_BLOCK_TRANSFER =
|
||||
factory_internal::SetFwdConvBBlockTransfer<BASE_ALGORITHM>();
|
||||
factory_internal::SetFwdConvBlockTransfer<BASE_ALGORITHM.transfer.b>();
|
||||
static constexpr auto C_BLOCK_TRANSFER =
|
||||
factory_internal::SetCBlockTransfer<SIGNATURE, BASE_ALGORITHM>();
|
||||
|
||||
|
||||
Reference in New Issue
Block a user