mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
WMMA grouped conv fwd large tensor extra flavors (#3582)
* Additional flavors for WMMA conv fwd large tensor
- added F16/BF16 clamp operation
- added F16/BF16 bias_clamp operation
- small modification to the device code to accomodate extra tensors
* changed strategy to handle GemmArgs array
* Adding generic instance
* Added generic instance to clamp and bias_clamp ops
[ROCm/composable_kernel commit: 81ee19bd2c]
This commit is contained in:
committed by
GitHub
parent
237d22d6ca
commit
21ab5fbe49
@@ -617,32 +617,32 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor
|
||||
const auto m_block = GridwiseGemm::CalculateMBlock(gemm_m);
|
||||
const auto n_block = GridwiseGemm::CalculateNBlock(gemm_n);
|
||||
|
||||
GemmArgs new_args{};
|
||||
new_args.a_ptrs_ = p_as_grid;
|
||||
new_args.b_ptrs_ = p_bs_grid;
|
||||
new_args.ds_ptrs_ = p_ds_grid;
|
||||
new_args.e_ptr_ = p_e_grid;
|
||||
|
||||
new_args.a_element_op_ = a_element_op_;
|
||||
new_args.b_element_op_ = b_element_op_;
|
||||
new_args.cde_element_op_ = cde_element_op_;
|
||||
|
||||
new_args.M_ = gemm_m;
|
||||
new_args.N_ = gemm_n;
|
||||
|
||||
new_args.a_grid_desc_ = a_grid_desc;
|
||||
new_args.b_grid_desc_ = b_grid_desc;
|
||||
new_args.ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
const auto ds_desc_mblock_mperblock_nblock_nperblock =
|
||||
GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
ds_grid_desc_m_n, m_block, n_block);
|
||||
new_args.e_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
const auto e_desc_mblock_mperblock_nblock_nperblock =
|
||||
GridwiseGemm::MakeDEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
|
||||
e_grid_desc_m_n, m_block, n_block);
|
||||
|
||||
new_args.BlockStart_ = BlockStart;
|
||||
new_args.BlockEnd_ = BlockEnd;
|
||||
|
||||
gemm_desc_kernel_args_.At(valid_gemms_count_) = new_args;
|
||||
gemm_desc_kernel_args_.Emplace(
|
||||
valid_gemms_count_,
|
||||
GemmArgs{.a_ptrs_ = p_as_grid,
|
||||
.b_ptrs_ = p_bs_grid,
|
||||
.ds_ptrs_ = p_ds_grid,
|
||||
.e_ptr_ = p_e_grid,
|
||||
.a_element_op_ = a_element_op_,
|
||||
.b_element_op_ = b_element_op_,
|
||||
.cde_element_op_ = cde_element_op_,
|
||||
.M_ = gemm_m,
|
||||
.N_ = gemm_n,
|
||||
.a_grid_desc_ = a_grid_desc,
|
||||
.b_grid_desc_ = b_grid_desc,
|
||||
.ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
ds_desc_mblock_mperblock_nblock_nperblock,
|
||||
.e_grid_desc_mblock_mperblock_nblock_nperblock_ =
|
||||
e_desc_mblock_mperblock_nblock_nperblock,
|
||||
.BlockStart_ = BlockStart,
|
||||
.BlockEnd_ = BlockEnd});
|
||||
|
||||
valid_gemms_count_++;
|
||||
}
|
||||
@@ -789,11 +789,14 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor
|
||||
compute_ptr_offset_of_n_.BatchStrideA_ = a_g_n_c_wis_strides_[1] * conv_N_per_block_;
|
||||
compute_ptr_offset_of_n_.BatchStrideE_ = e_g_n_k_wos_strides_[1] * conv_N_per_block_;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][0];
|
||||
compute_ptr_offset_of_n_.BatchStrideDs_(i) =
|
||||
ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_;
|
||||
});
|
||||
if constexpr(NumDTensor > 0)
|
||||
{
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
compute_ptr_offset_of_groups_.BatchStrideDs_(i) = ds_g_n_k_wos_strides_[i][0];
|
||||
compute_ptr_offset_of_n_.BatchStrideDs_(i) =
|
||||
ds_g_n_k_wos_strides_[i][1] * conv_N_per_block_;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void Print() const
|
||||
@@ -807,12 +810,15 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor
|
||||
<< ", is_split_valid=" << std::boolalpha << is_split_valid_
|
||||
<< std::noboolalpha << ", grid_size=" << grid_size_ << std::endl;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
std::cout << " Ds[" << i.value
|
||||
<< "] group stride=" << compute_ptr_offset_of_groups_.BatchStrideDs_(i)
|
||||
<< ", n stride=" << compute_ptr_offset_of_n_.BatchStrideDs_(i)
|
||||
<< std::endl;
|
||||
});
|
||||
if constexpr(NumDTensor > 0)
|
||||
{
|
||||
static_for<0, NumDTensor, 1>{}([&](auto i) {
|
||||
std::cout << " Ds[" << i.value << "] group stride="
|
||||
<< compute_ptr_offset_of_groups_.BatchStrideDs_.At(i)
|
||||
<< ", n stride=" << compute_ptr_offset_of_n_.BatchStrideDs_.At(i)
|
||||
<< std::endl;
|
||||
});
|
||||
}
|
||||
|
||||
std::cout << "===== GEMM splits =====" << std::endl;
|
||||
for(index_t i = 0; i < valid_gemms_count_; ++i)
|
||||
@@ -836,11 +842,14 @@ struct DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor
|
||||
std::cout << " E[MBlock, MPerBlock, NBlock, NPerBlock]: "
|
||||
<< gemm.e_grid_desc_mblock_mperblock_nblock_nperblock_ << std::endl;
|
||||
|
||||
static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
|
||||
std::cout << " D" << d_idx.value << " descriptor: "
|
||||
<< gemm.ds_grid_desc_mblock_mperblock_nblock_nperblock_(d_idx)
|
||||
<< std::endl;
|
||||
});
|
||||
if constexpr(NumDTensor > 0)
|
||||
{
|
||||
static_for<0, NumDTensor, 1>{}([&](auto d_idx) {
|
||||
std::cout << " D" << d_idx.value << " descriptor: "
|
||||
<< gemm.ds_grid_desc_mblock_mperblock_nblock_nperblock_.At(d_idx)
|
||||
<< std::endl;
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
|
||||
#include "functional2.hpp"
|
||||
#include "sequence.hpp"
|
||||
#include <type_traits>
|
||||
#include <cassert>
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -27,6 +29,15 @@ struct Array
|
||||
|
||||
__host__ __device__ constexpr TData& operator()(index_t i) { return At(i); }
|
||||
|
||||
template <typename... Args>
|
||||
__host__ constexpr auto Emplace(index_t i, Args&&... args)
|
||||
-> std::enable_if_t<std::is_nothrow_constructible_v<TData, Args&&...>>
|
||||
{
|
||||
assert(i >= 0 && i < NSize);
|
||||
mData[i].~TData();
|
||||
new(mData + i) TData(ck::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr auto operator=(const T& a)
|
||||
{
|
||||
|
||||
@@ -29,12 +29,32 @@ using S = ck::Sequence<Is...>;
|
||||
using namespace ck::tensor_layout::convolution;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using AddClamp = ck::tensor_operation::element_wise::AddClamp;
|
||||
using Clamp = ck::tensor_operation::element_wise::Clamp;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
static constexpr auto GemmMNKPadding = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
|
||||
|
||||
template <index_t NDSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionForwardSpecialization ConvSpec,
|
||||
typename DsDataType = Empty_Tuple,
|
||||
typename CDEElementOp = PassThrough>
|
||||
using device_grouped_conv_fwd_wmma_large_tensor_f16_generic_instances = std::tuple<
|
||||
// clang-format off
|
||||
//########################################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//########################################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector|
|
||||
//########################################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma|
|
||||
//########################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataType, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
@@ -45,9 +65,10 @@ template <index_t NDSpatial,
|
||||
typename CDEElementOp = PassThrough>
|
||||
using device_grouped_conv_fwd_wmma_large_tensor_f16_instances = std::tuple<
|
||||
// clang-format off
|
||||
//########################################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//########################################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//########################################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
//########################################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//########################################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector|
|
||||
//########################################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma|
|
||||
//########################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataType, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataType, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor<NDSpatial, ALayout, BLayout, DsLayout, ELayout, F16, F16, F32, F16, DsDataType, F16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
|
||||
@@ -56,6 +77,24 @@ using device_grouped_conv_fwd_wmma_large_tensor_f16_instances = std::tuple<
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
typename DsLayout,
|
||||
typename ELayout,
|
||||
ConvolutionForwardSpecialization ConvSpec,
|
||||
typename DsDataType = Empty_Tuple,
|
||||
typename CDEElementOp = PassThrough>
|
||||
using device_grouped_conv_fwd_wmma_large_tensor_bf16_generic_instances = std::tuple<
|
||||
// clang-format off
|
||||
//########################################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//########################################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector|
|
||||
//########################################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma|
|
||||
//########################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor<NDSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataType, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 64, 64, 64, 32, 8, 16, 16, 4, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 8, 1, 1, 1, S<1, 16, 1, 4>, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
template <index_t NDSpatial,
|
||||
typename ALayout,
|
||||
typename BLayout,
|
||||
@@ -66,9 +105,10 @@ template <index_t NDSpatial,
|
||||
typename CDEElementOp = PassThrough>
|
||||
using device_grouped_conv_fwd_wmma_large_tensor_bf16_instances = std::tuple<
|
||||
// clang-format off
|
||||
//########################################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MRepeat| NRepeat| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//########################################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| | | Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _MBlock_MWaveMPerXdl| ScalarPerVector|
|
||||
//########################################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
//########################################################| NumDim| A| B| Ds| E| AData| BData| AccData| CShuffle| Ds| EData| A| B| CDE| ConvForward| GEMM| Block| MPer| NPer| KPer| K1| MPer| NPer| MWmma| NWmma| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer|
|
||||
//########################################################| Spatial| Layout| Layout| Layout| Layout| Type| Type| Type| DataType| DataType| Type| Elementwise| Elementwise| Elementwise| Specialization| Specialization| Size| Block| Block| Block| | WMMA| WMMA| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN|MWmmaPerWave|NWmmaPerWave| _MBlock_MWaveMPerWmma| ScalarPerVector|
|
||||
//########################################################| | | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerWmma| _NWaveNPerWmma|
|
||||
//########################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor<NDSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataType, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 128, 128, 64, 8, 16, 16, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor<NDSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataType, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 256, 64, 64, 64, 8, 16, 16, 2, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 8>, 1>,
|
||||
DeviceGroupedConvFwdMultipleD_Wmma_CShuffle_V3_Large_Tensor<NDSpatial, ALayout, BLayout, DsLayout, ELayout, BF16, BF16, F32, BF16, DsDataType, BF16, PassThrough, PassThrough, CDEElementOp, ConvSpec, GemmMNKPadding, 128, 128, 128, 32, 8, 16, 16, 4, 4, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, 1, 1, 1, S<1, 32, 1, 4>, 1>,
|
||||
|
||||
@@ -293,8 +293,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part4(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
// op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_generic_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
@@ -310,8 +312,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part4(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
// op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_generic_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -334,8 +338,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part4(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
// op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_generic_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
@@ -351,8 +357,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part4(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
// op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_generic_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -63,20 +63,33 @@ void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_
|
||||
PassThrough,
|
||||
AddClamp>>>& instances);
|
||||
|
||||
// void
|
||||
// add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
// NHWGC,
|
||||
// GKYXC,
|
||||
// Tuple<NHWGK>,
|
||||
// NHWGK,
|
||||
// BF16,
|
||||
// BF16,
|
||||
// Tuple<BF16>,
|
||||
// BF16,
|
||||
// PassThrough,
|
||||
// PassThrough,
|
||||
// AddClamp>>>& instances);
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_generic_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part1(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
@@ -131,20 +144,33 @@ void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhw
|
||||
PassThrough,
|
||||
AddClamp>>>& instances);
|
||||
|
||||
// void
|
||||
// add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
// NDHWGC,
|
||||
// GKZYXC,
|
||||
// Tuple<NDHWGK>,
|
||||
// NDHWGK,
|
||||
// BF16,
|
||||
// BF16,
|
||||
// Tuple<BF16>,
|
||||
// BF16,
|
||||
// PassThrough,
|
||||
// PassThrough,
|
||||
// AddClamp>>>& instances);
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_generic_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp>>>& instances);
|
||||
|
||||
#endif
|
||||
|
||||
@@ -203,20 +229,33 @@ void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_
|
||||
PassThrough,
|
||||
AddClamp>>>& instances);
|
||||
|
||||
// void
|
||||
// add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
// NHWGC,
|
||||
// GKYXC,
|
||||
// Tuple<NHWGK>,
|
||||
// NHWGK,
|
||||
// F16,
|
||||
// F16,
|
||||
// Tuple<F16>,
|
||||
// F16,
|
||||
// PassThrough,
|
||||
// PassThrough,
|
||||
// AddClamp>>>& instances);
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_generic_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part1(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
@@ -271,20 +310,33 @@ void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhw
|
||||
PassThrough,
|
||||
AddClamp>>>& instances);
|
||||
|
||||
// void
|
||||
// add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
// NDHWGC,
|
||||
// GKZYXC,
|
||||
// Tuple<NDHWGK>,
|
||||
// NDHWGK,
|
||||
// F16,
|
||||
// F16,
|
||||
// Tuple<F16>,
|
||||
// F16,
|
||||
// PassThrough,
|
||||
// PassThrough,
|
||||
// AddClamp>>>& instances);
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_generic_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp>>>& instances);
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
@@ -290,8 +290,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_instances_part4(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
// op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_generic_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
@@ -307,8 +309,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instances_part4(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
// op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_generic_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
@@ -331,8 +335,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part4(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
// op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_generic_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
#ifdef CK_ENABLE_FP16
|
||||
@@ -348,8 +354,10 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceGroupe
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instances_part4(
|
||||
op_ptrs);
|
||||
// add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
// op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
op_ptrs);
|
||||
add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_generic_instances(
|
||||
op_ptrs);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
@@ -63,20 +63,33 @@ void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_bf16_
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
// void
|
||||
// add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
// NHWGC,
|
||||
// GKYXC,
|
||||
// Tuple<>,
|
||||
// NHWGK,
|
||||
// BF16,
|
||||
// BF16,
|
||||
// Tuple<>,
|
||||
// BF16,
|
||||
// PassThrough,
|
||||
// PassThrough,
|
||||
// Clamp>>>& instances);
|
||||
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_generic_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf16_instances_part1(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
@@ -131,20 +144,33 @@ void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_bf
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
// void
|
||||
// add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
// NDHWGC,
|
||||
// GKZYXC,
|
||||
// Tuple<>,
|
||||
// NDHWGK,
|
||||
// BF16,
|
||||
// BF16,
|
||||
// Tuple<>,
|
||||
// BF16,
|
||||
// PassThrough,
|
||||
// PassThrough,
|
||||
// Clamp>>>& instances);
|
||||
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_generic_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
#endif
|
||||
|
||||
@@ -256,35 +282,61 @@ void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f1
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
// void
|
||||
// add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
// NHWGC,
|
||||
// GKYXC,
|
||||
// Tuple<>,
|
||||
// NHWGK,
|
||||
// F16,
|
||||
// F16,
|
||||
// Tuple<>,
|
||||
// F16,
|
||||
// PassThrough,
|
||||
// PassThrough,
|
||||
// Clamp>>>& instances);
|
||||
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
// void
|
||||
// add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
// std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
// NDHWGC,
|
||||
// GKZYXC,
|
||||
// Tuple<>,
|
||||
// NDHWGK,
|
||||
// F16,
|
||||
// F16,
|
||||
// Tuple<>,
|
||||
// F16,
|
||||
// PassThrough,
|
||||
// PassThrough,
|
||||
// Clamp>>>& instances);
|
||||
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_generic_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_generic_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances);
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
@@ -49,4 +49,8 @@ add_instance_library(device_grouped_conv2d_fwd_bias_clamp_instance
|
||||
wmma/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance_part2.cpp
|
||||
wmma/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance_part3.cpp
|
||||
wmma/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance_part4.cpp
|
||||
wmma/large_tensor/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
wmma/large_tensor/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
wmma/large_tensor/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_generic_instance.cpp
|
||||
wmma/large_tensor/device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_generic_instance.cpp
|
||||
)
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_generic_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_large_tensor_bf16_generic_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<BF16>,
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_large_tensor_bf16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<BF16>,
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_generic_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_large_tensor_f16_generic_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<F16>,
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv2d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_large_tensor_f16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<NHWGK>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<F16>,
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -49,4 +49,8 @@ add_instance_library(device_grouped_conv2d_fwd_clamp_instance
|
||||
wmma/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance_part2.cpp
|
||||
wmma/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance_part3.cpp
|
||||
wmma/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_nhwgc_gkyxc_nhwgk_f16_instance_part4.cpp
|
||||
wmma/large_tensor/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instance.cpp
|
||||
wmma/large_tensor/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instance.cpp
|
||||
wmma/large_tensor/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_generic_instance.cpp
|
||||
wmma/large_tensor/device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_generic_instance.cpp
|
||||
)
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_generic_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_large_tensor_bf16_generic_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_large_tensor_bf16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_generic_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_large_tensor_f16_generic_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv2d_fwd_clamp_wmma_cshufflev3_large_tensor_nhwgc_gkyxc_nhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_large_tensor_f16_instances<2,
|
||||
NHWGC,
|
||||
GKYXC,
|
||||
Tuple<>,
|
||||
NHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -44,6 +44,10 @@ set(GROUPED_CONV3D_FWD
|
||||
wmma/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance_part2.cpp
|
||||
wmma/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance_part3.cpp
|
||||
wmma/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance_part4.cpp
|
||||
wmma/large_tensor/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
wmma/large_tensor/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
wmma/large_tensor/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_generic_instance.cpp
|
||||
wmma/large_tensor/device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_generic_instance.cpp
|
||||
)
|
||||
|
||||
add_instance_library(device_grouped_conv3d_fwd_bias_clamp_instance ${GROUPED_CONV3D_FWD})
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_generic_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_large_tensor_bf16_generic_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<BF16>,
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<BF16>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_large_tensor_bf16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<BF16>,
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_generic_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_large_tensor_f16_generic_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<F16>,
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_fwd_bias_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<F16>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
AddClamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_large_tensor_f16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<NDHWGK>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<F16>,
|
||||
AddClamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -44,6 +44,10 @@ set(GROUPED_CONV3D_FWD
|
||||
wmma/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance_part2.cpp
|
||||
wmma/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance_part3.cpp
|
||||
wmma/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_ndhwgc_gkzyxc_ndhwgk_f16_instance_part4.cpp
|
||||
wmma/large_tensor/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instance.cpp
|
||||
wmma/large_tensor/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instance.cpp
|
||||
wmma/large_tensor/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_generic_instance.cpp
|
||||
wmma/large_tensor/device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_generic_instance.cpp
|
||||
)
|
||||
|
||||
add_instance_library(device_grouped_conv3d_fwd_clamp_instance ${GROUPED_CONV3D_FWD})
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_generic_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_large_tensor_bf16_generic_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_bf16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
BF16,
|
||||
BF16,
|
||||
Tuple<>,
|
||||
BF16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_large_tensor_bf16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_generic_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_large_tensor_f16_generic_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
|
||||
#include "ck/library/tensor_operation_instance/gpu/grouped_conv_fwd/device_grouped_conv_fwd_wmma_cshufflev3_large_tensor_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace instance {
|
||||
void add_device_grouped_conv3d_fwd_clamp_wmma_cshufflev3_large_tensor_ndhwgc_gkzyxc_ndhwgk_f16_instances(
|
||||
std::vector<std::unique_ptr<DeviceGroupedConvFwdMultipleABD<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
F16,
|
||||
F16,
|
||||
Tuple<>,
|
||||
F16,
|
||||
PassThrough,
|
||||
PassThrough,
|
||||
Clamp>>>& instances)
|
||||
{
|
||||
add_device_operation_instances(
|
||||
instances,
|
||||
device_grouped_conv_fwd_wmma_large_tensor_f16_instances<3,
|
||||
NDHWGC,
|
||||
GKZYXC,
|
||||
Tuple<>,
|
||||
NDHWGK,
|
||||
ConvFwdDefault,
|
||||
Tuple<>,
|
||||
Clamp>{});
|
||||
}
|
||||
|
||||
} // namespace instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
Reference in New Issue
Block a user