mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-17 19:40:04 +00:00
ck-builder: group transfer operations per tensor (#3217)
Grouping transfer operations per tensor makes it easier to
constrain on and operate with the transfer operations. As an
example, we can now deduplicate the logic for translating
the transfer operations from the ck-builder interface to the old
ck interface for the A and B tensors.
[ROCm/composable_kernel commit: 245c6011cf]
This commit is contained in:
@@ -103,18 +103,25 @@ struct AccessOrder
|
||||
};
|
||||
static_assert(AccessOrderDescriptor<AccessOrder>);
|
||||
|
||||
struct BlockTransferABC
|
||||
struct TransferAB
|
||||
{
|
||||
BlockTransfer block_transfer_a;
|
||||
BlockTransfer block_transfer_b;
|
||||
ThreadCluster thread_cluster_dims_c;
|
||||
LdsTransfer lds_transfer_a;
|
||||
LdsTransfer lds_transfer_b;
|
||||
Epilogue epilogue_c;
|
||||
AccessOrder block_transfer_access_order_a;
|
||||
AccessOrder block_transfer_access_order_b;
|
||||
AccessOrder src_access_order_a;
|
||||
AccessOrder src_access_order_b;
|
||||
BlockTransfer block_transfer;
|
||||
LdsTransfer lds_transfer;
|
||||
AccessOrder block_transfer_access_order;
|
||||
AccessOrder src_access_order;
|
||||
};
|
||||
|
||||
struct TransferC
|
||||
{
|
||||
ThreadCluster thread_cluster_dims;
|
||||
Epilogue epilogue;
|
||||
};
|
||||
|
||||
struct TransferABC
|
||||
{
|
||||
TransferAB a;
|
||||
TransferAB b;
|
||||
TransferC c;
|
||||
};
|
||||
|
||||
// DL-specific descriptors
|
||||
@@ -172,9 +179,9 @@ struct WmmaGemm_
|
||||
GridwiseWmmaGemm gridwise_gemm;
|
||||
};
|
||||
|
||||
struct BlockTransfer_
|
||||
struct Transfer_
|
||||
{
|
||||
BlockTransferABC block_transfer;
|
||||
TransferABC transfer;
|
||||
};
|
||||
|
||||
struct ConvSpecialization_
|
||||
@@ -205,15 +212,26 @@ struct DlThreadCluster_
|
||||
DlThreadCluster thread_cluster;
|
||||
};
|
||||
|
||||
struct DlBlockTransfer_
|
||||
struct DlBlockTransferAB
|
||||
{
|
||||
DlBlockTransfer block_transfer_a;
|
||||
DlBlockTransfer block_transfer_b;
|
||||
DlBlockTransfer block_transfer;
|
||||
};
|
||||
|
||||
struct DlEpilogue_
|
||||
struct DlBlockTransferC
|
||||
{
|
||||
DlEpilogue epilogue_c;
|
||||
DlEpilogue epilogue;
|
||||
};
|
||||
|
||||
struct DlTransferABC
|
||||
{
|
||||
DlBlockTransferAB a;
|
||||
DlBlockTransferAB b;
|
||||
DlBlockTransferC c;
|
||||
};
|
||||
|
||||
struct DlTransfer_
|
||||
{
|
||||
DlTransferABC transfer;
|
||||
};
|
||||
|
||||
// Specialization wrapper for large tensor support
|
||||
@@ -255,12 +273,12 @@ struct ConvAlgorithmTemplate : Components...
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename BT>
|
||||
constexpr auto with_block_transfer(const BT& bt) const
|
||||
template <typename T>
|
||||
constexpr auto with_transfer(const T& t) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<BlockTransfer_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.block_transfer = bt;
|
||||
static_assert(std::is_base_of_v<Transfer_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.transfer = t;
|
||||
return result;
|
||||
}
|
||||
|
||||
@@ -313,21 +331,12 @@ struct ConvAlgorithmTemplate : Components...
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename BTA, typename BTB>
|
||||
constexpr auto with_dl_block_transfer(const BTA& bta, const BTB& btb) const
|
||||
template <typename T>
|
||||
constexpr auto with_dl_transfer(const T& t) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<DlBlockTransfer_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.block_transfer_a = bta;
|
||||
result.block_transfer_b = btb;
|
||||
return result;
|
||||
}
|
||||
|
||||
constexpr auto with_dl_epilogue(const DlEpilogue& epi) const
|
||||
{
|
||||
static_assert(std::is_base_of_v<DlEpilogue_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.epilogue_c = epi;
|
||||
static_assert(std::is_base_of_v<DlTransfer_, ConvAlgorithmTemplate>);
|
||||
auto result = *this;
|
||||
result.transfer = t;
|
||||
return result;
|
||||
}
|
||||
};
|
||||
@@ -335,20 +344,19 @@ struct ConvAlgorithmTemplate : Components...
|
||||
// Algorithm types
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_, XdlGemm_, BlockTransfer_, ConvSpecialization_, Prefetch_>;
|
||||
ConvAlgorithmTemplate<ThreadBlock_, XdlGemm_, Transfer_, ConvSpecialization_, Prefetch_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3 =
|
||||
ConvAlgorithmTemplate<ThreadBlock_, XdlGemm_, BlockTransfer_, ConvSpecialization_, BlockGemm_>;
|
||||
ConvAlgorithmTemplate<ThreadBlock_, XdlGemm_, Transfer_, ConvSpecialization_, BlockGemm_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Wmma_CShuffle =
|
||||
ConvAlgorithmTemplate<ThreadBlock_, WmmaGemm_, BlockTransfer_, ConvSpecialization_, Prefetch_>;
|
||||
ConvAlgorithmTemplate<ThreadBlock_, WmmaGemm_, Transfer_, ConvSpecialization_, Prefetch_>;
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK =
|
||||
ConvAlgorithmTemplate<ThreadBlock_,
|
||||
ConvSpecialization_,
|
||||
DlThreadConfig_,
|
||||
DlThreadCluster_,
|
||||
DlBlockTransfer_,
|
||||
DlEpilogue_>;
|
||||
DlTransfer_>;
|
||||
|
||||
using ConvAlgorithm_DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor =
|
||||
LargeTensorWrapper<ConvAlgorithm_DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle>;
|
||||
|
||||
Reference in New Issue
Block a user