mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
Testing all fwd convolution specializations. (#259)
* UniforFill with integer values. * Log tested instance type string. * Add UT for all convolution specializations. * debugging conv * Fix dangling reference bug. * Small refinements. * Fix call to error checking function. * Small refinements to tests. * Configure error tolerance * Change problem size. * Remove OddC case from types that do not support it. * Add helper traits for AccumulatorDataType. * Print first 5 errs in check_err for integral types. * Rename FillUniform to FillUniformDistribution * Refactor * Do not use typed tests. * Instead use plain fixture class with templatized member functions. * Initialize tensors with integer values. * Refine test instances. * Properly set accumulator data type. * Add another "big" instance. * Refactor convolution tests. * Revert "debugging conv" This reverts commitb109516455. * Add pragma once + format + small refinement. * Fix some unwanted changes. * Clang-format * Fix profile_convnd to use renamed tensor initializer. * Add instances for ConvFWDND kernel case 2D * Helpers to get ConvNDFwd 2D instances. * Refactoring. * Remove "small block" instance as it was generating compiler errors. * Remove default template parameters values. * Refine and fix test. * Fix problem with default template parameter types. * Adjust error thresholds for floating point values test. * Use integer values initialization for instances test. * Add tests for ConvNDFwd 2D case. * Remove AccumulatorDataType type trait. * Update unit-tests. * Remove operator<< overload. * Unlock conv1d/3d nd fwd instances. * Enable skipping calculating reference using flag. * Fix number of channels for first ResNet50 layer. * Clang-format. Co-authored-by: Adam Osewski <aosewski@amd.com> Co-authored-by: Chao Liu <chao.liu2@amd.com> [ROCm/composable_kernel commit:a2edd7d802]
This commit is contained in:
@@ -291,8 +291,8 @@ int main(int argc, char* argv[])
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / ave_time;
|
||||
float gb_per_sec = num_btype / 1.E6 / ave_time;
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, " << conv->GetTypeString()
|
||||
<< std::endl;
|
||||
std::cout << "Perf: " << ave_time << " ms, " << tflops << " TFlops, " << gb_per_sec << " GB/s, "
|
||||
<< conv->GetTypeString() << std::endl;
|
||||
|
||||
if(do_verification)
|
||||
{
|
||||
|
||||
@@ -163,10 +163,6 @@
|
||||
// tuning parameter
|
||||
#define CK_WORKAROUND_SWDEV_325164 1
|
||||
|
||||
// workaround for verification failure ConvNd forward
|
||||
// https://github.com/ROCmSoftwarePlatform/composable_kernel/issues/135
|
||||
#define CK_WORKAROUND_GITHUB_135 1
|
||||
|
||||
namespace ck {
|
||||
|
||||
enum struct InMemoryDataOperationEnum
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#ifndef CHECK_ERR_HPP
|
||||
#define CHECK_ERR_HPP
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
@@ -169,17 +168,34 @@ check_err(const std::vector<T>& out,
|
||||
return false;
|
||||
}
|
||||
|
||||
bool res{true};
|
||||
int err_count = 0;
|
||||
int64_t err = 0;
|
||||
int64_t max_err = std::numeric_limits<int64_t>::min();
|
||||
for(std::size_t i = 0; i < ref.size(); ++i)
|
||||
{
|
||||
if(out[i] != ref[i])
|
||||
int64_t o = out[i];
|
||||
int64_t r = ref[i];
|
||||
err = std::abs(o - r);
|
||||
|
||||
if(err > 0)
|
||||
{
|
||||
std::cout << "out[" << i << "] != ref[" << i << "]: " << static_cast<int>(out[i])
|
||||
<< " != " << static_cast<int>(ref[i]) << std::endl
|
||||
<< msg << std::endl;
|
||||
return false;
|
||||
max_err = err > max_err ? err : max_err;
|
||||
err_count++;
|
||||
if(err_count < 5)
|
||||
{
|
||||
std::cout << "out[" << i << "] != ref[" << i << "]: " << static_cast<int>(out[i])
|
||||
<< " != " << static_cast<int>(ref[i]) << std::endl
|
||||
<< msg << std::endl;
|
||||
}
|
||||
res = false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
if(!res)
|
||||
{
|
||||
std::cout << "max err: " << max_err << std::endl;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
} // namespace utils
|
||||
@@ -191,5 +207,3 @@ std::ostream& operator<<(std::ostream& os, const std::vector<T>& v)
|
||||
std::copy(std::begin(v), std::end(v), std::ostream_iterator<T>(os, " "));
|
||||
return os;
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
@@ -402,8 +402,8 @@ template <typename InDataType,
|
||||
typename InElementwiseOp = ck::tensor_operation::element_wise::PassThrough,
|
||||
typename WeiElementwiseOp = ck::tensor_operation::element_wise::PassThrough,
|
||||
typename OutElementwiseOp = ck::tensor_operation::element_wise::PassThrough,
|
||||
typename InputInitFun = FillUniform<InDataType>,
|
||||
typename WeightsInitFun = FillUniform<WeiDataType>>
|
||||
typename InputInitFun = FillUniformDistribution<InDataType>,
|
||||
typename WeightsInitFun = FillUniformDistribution<WeiDataType>>
|
||||
class ConvFwdOpInstance : public ck::utils::OpInstance<OutDataType, InDataType, WeiDataType>
|
||||
{
|
||||
using DeviceConvFwdOp = tensor_operation::device::
|
||||
@@ -422,8 +422,8 @@ class ConvFwdOpInstance : public ck::utils::OpInstance<OutDataType, InDataType,
|
||||
|
||||
ConvFwdOpInstance(const ConvParams& params,
|
||||
bool do_init = true,
|
||||
const InputInitFun& input_init_f = InputInitFun{},
|
||||
const WeightsInitFun& weights_init_f = WeightsInitFun{})
|
||||
const InputInitFun& input_init_f = InputInitFun(),
|
||||
const WeightsInitFun& weights_init_f = WeightsInitFun())
|
||||
: BaseType(),
|
||||
params_{params},
|
||||
output_spatial_lengths_{params.GetOutputSpatialLengths()},
|
||||
@@ -560,8 +560,8 @@ class ConvFwdOpInstance : public ck::utils::OpInstance<OutDataType, InDataType,
|
||||
const ConvParams& params_;
|
||||
const std::vector<ck::index_t> output_spatial_lengths_;
|
||||
const bool do_init_;
|
||||
const InputInitFun& input_init_f_;
|
||||
const WeightsInitFun& weights_init_f_;
|
||||
InputInitFun input_init_f_;
|
||||
WeightsInitFun weights_init_f_;
|
||||
};
|
||||
|
||||
} // namespace conv
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <random>
|
||||
|
||||
#include "data_type.hpp"
|
||||
@@ -8,46 +9,56 @@
|
||||
namespace ck {
|
||||
namespace utils {
|
||||
|
||||
// template <typename T, class Enable = void>
|
||||
// struct FillUniform;
|
||||
|
||||
// TODO: what's wrong with this specialization???
|
||||
// err: segmentation fault in mt19937 - infinite loop like.
|
||||
// template <typename T>
|
||||
// struct FillUniform<T, typename std::enable_if<std::is_integral<T>::value &&
|
||||
// !std::is_same<T, bhalf_t>::value>::type>
|
||||
// {
|
||||
// int a_{0};
|
||||
// int b_{5};
|
||||
// // T a_ = T{0};
|
||||
// // T b_ = T{5};
|
||||
|
||||
// template <typename ForwardIter>
|
||||
// void operator()(ForwardIter first, ForwardIter last) const
|
||||
// {
|
||||
// std::mt19937 gen{11939};
|
||||
// std::uniform_int_distribution<int> dis(a_, b_);
|
||||
// std::generate(first, last, [&dis, &gen]() { return ck::type_convert<T>(dis(gen)); });
|
||||
// }
|
||||
// };
|
||||
|
||||
// struct FillUniform<T, typename std::enable_if<std::is_floating_point<T>::value ||
|
||||
// std::is_same<T, bhalf_t>::value>::type>
|
||||
template <typename T>
|
||||
struct FillUniform
|
||||
struct FillUniformDistribution
|
||||
{
|
||||
float a_{0};
|
||||
float b_{5};
|
||||
float a_{-5.f};
|
||||
float b_{5.f};
|
||||
|
||||
template <typename ForwardIter>
|
||||
void operator()(ForwardIter first, ForwardIter last) const
|
||||
{
|
||||
std::mt19937 gen{11939};
|
||||
std::uniform_real_distribution<> dis(a_, b_);
|
||||
std::mt19937 gen(11939);
|
||||
std::uniform_real_distribution<float> dis(a_, b_);
|
||||
std::generate(first, last, [&dis, &gen]() { return ck::type_convert<T>(dis(gen)); });
|
||||
}
|
||||
};
|
||||
|
||||
// Normally FillUniformDistributionIntegerValue should use std::uniform_int_distribution as below.
|
||||
// However this produces segfaults in std::mt19937 which look like inifite loop.
|
||||
// template <typename T>
|
||||
// struct FillUniformDistributionIntegerValue
|
||||
// {
|
||||
// int a_{-5};
|
||||
// int b_{5};
|
||||
//
|
||||
// template <typename ForwardIter>
|
||||
// void operator()(ForwardIter first, ForwardIter last) const
|
||||
// {
|
||||
// std::mt19937 gen(11939);
|
||||
// std::uniform_int_distribution<int> dis(a_, b_);
|
||||
// std::generate(
|
||||
// first, last, [&dis, &gen]() { return ck::type_convert<T>(dis(gen)); });
|
||||
// }
|
||||
// };
|
||||
|
||||
// Workaround for uniform_int_distribution not working as expected. See note above.<
|
||||
template <typename T>
|
||||
struct FillUniformDistributionIntegerValue
|
||||
{
|
||||
float a_{-5.f};
|
||||
float b_{5.f};
|
||||
|
||||
template <typename ForwardIter>
|
||||
void operator()(ForwardIter first, ForwardIter last) const
|
||||
{
|
||||
std::mt19937 gen(11939);
|
||||
std::uniform_real_distribution<float> dis(a_, b_);
|
||||
std::generate(
|
||||
first, last, [&dis, &gen]() { return ck::type_convert<T>(std::round(dis(gen))); });
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct FillMonotonicSeq
|
||||
{
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <memory>
|
||||
#include <stdexcept>
|
||||
@@ -78,7 +79,8 @@ class OpInstanceRunEngine
|
||||
|
||||
template <typename ReferenceOp = std::function<void()>>
|
||||
OpInstanceRunEngine(const OpInstanceT& op_instance,
|
||||
const ReferenceOp& reference_op = ReferenceOp{})
|
||||
const ReferenceOp& reference_op = ReferenceOp{},
|
||||
bool do_verification = true)
|
||||
: op_instance_{op_instance}
|
||||
{
|
||||
in_tensors_ = op_instance_.GetInputTensors();
|
||||
@@ -88,8 +90,11 @@ class OpInstanceRunEngine
|
||||
const Tensor<InArgTypes>&...,
|
||||
Tensor<OutDataType>&>)
|
||||
{
|
||||
ref_output_ = op_instance_.GetOutputTensor();
|
||||
CallRefOpUnpackArgs(reference_op, std::make_index_sequence<kNInArgs_>{});
|
||||
if(do_verification)
|
||||
{
|
||||
ref_output_ = op_instance_.GetOutputTensor();
|
||||
CallRefOpUnpackArgs(reference_op, std::make_index_sequence<kNInArgs_>{});
|
||||
}
|
||||
}
|
||||
AllocateDeviceInputTensors(std::make_index_sequence<kNInArgs_>{});
|
||||
out_device_buffer_ =
|
||||
@@ -110,6 +115,7 @@ class OpInstanceRunEngine
|
||||
op_ptr.get(), in_device_buffers_, out_device_buffer_);
|
||||
if(op_ptr->IsSupportedArgument(argument.get()))
|
||||
{
|
||||
std::cout << "Testing instance: " << op_ptr->GetTypeString() << std::endl;
|
||||
invoker->Run(argument.get());
|
||||
out_device_buffer_->FromDevice(out_tensor_->mData.data());
|
||||
if(!ref_output_)
|
||||
@@ -119,9 +125,16 @@ class OpInstanceRunEngine
|
||||
" You have to provide reference function.");
|
||||
}
|
||||
// TODO: enable flexible use of custom check_error functions
|
||||
res = res && check_err(out_tensor_->mData, ref_output_->mData);
|
||||
bool inst_res = CheckErr(out_tensor_->mData, ref_output_->mData);
|
||||
std::cout << (inst_res ? "SUCCESS" : "FAILURE") << std::endl;
|
||||
res = res && inst_res;
|
||||
out_device_buffer_->SetZero();
|
||||
}
|
||||
else
|
||||
{
|
||||
std::cout << "Given conv problem is not supported by instance: \n\t>>>>"
|
||||
<< op_ptr->GetTypeString() << std::endl;
|
||||
}
|
||||
}
|
||||
return res;
|
||||
}
|
||||
@@ -132,7 +145,6 @@ class OpInstanceRunEngine
|
||||
bool do_verification = false,
|
||||
bool do_log = false)
|
||||
{
|
||||
bool res{true};
|
||||
ProfileBestConfig best_config;
|
||||
|
||||
for(auto& op_ptr : op_ptrs)
|
||||
@@ -153,7 +165,7 @@ class OpInstanceRunEngine
|
||||
std::cout << "Perf: " << avg_time << " ms, " << tflops << " TFlops, " << gb_per_sec
|
||||
<< " GB/s, " << op_name << std::endl;
|
||||
|
||||
if(tflops < best_config.best_tflops)
|
||||
if(avg_time < best_config.best_avg_time)
|
||||
{
|
||||
best_config.best_op_name = op_name;
|
||||
best_config.best_tflops = tflops;
|
||||
@@ -171,7 +183,7 @@ class OpInstanceRunEngine
|
||||
" You have to provide reference function.");
|
||||
}
|
||||
// TODO: enable flexible use of custom check_error functions
|
||||
res = res && CheckErr(out_tensor_->mData, ref_output_->mData);
|
||||
CheckErr(out_tensor_->mData, ref_output_->mData);
|
||||
|
||||
if(do_log) {}
|
||||
}
|
||||
@@ -223,7 +235,7 @@ class OpInstanceRunEngine
|
||||
template <typename T>
|
||||
bool CheckErr(const std::vector<T>& dev_out, const std::vector<T>& ref_out) const
|
||||
{
|
||||
return ck::utils::check_err(dev_out, ref_out, "Error: incorrect results!", atol_, rtol_);
|
||||
return ck::utils::check_err(dev_out, ref_out, "Error: incorrect results!", rtol_, atol_);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -28,15 +28,12 @@ static constexpr auto ConvFwd1x1S1P0 =
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
|
||||
using device_conv1d_fwd_xdl_nwc_kxc_nwk_bf16_instances = std::tuple<
|
||||
// clang-format off
|
||||
// clang-format off
|
||||
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
#if !CK_WORKAROUND_GITHUB_135
|
||||
// FIXME: this instance causes numerical errors.
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
#endif
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 1, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
|
||||
@@ -6,7 +6,18 @@ set(DEVICE_CONV2D_FWD_INSTANCE_SOURCE
|
||||
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp;
|
||||
device_conv2d_fwd_xdl_c_shuffle_nhwc_kyxc_nhwk_f16_instance.cpp;
|
||||
)
|
||||
set(DEVICE_CONVND_2D_FWD_INSTANCE_SOURCE
|
||||
device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instance.cpp;
|
||||
device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instance.cpp;
|
||||
device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp;
|
||||
device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instance.cpp;
|
||||
)
|
||||
|
||||
add_library(device_conv2d_fwd_instance OBJECT ${DEVICE_CONV2D_FWD_INSTANCE_SOURCE})
|
||||
add_library(device_convnd_2d_fwd_instance OBJECT ${DEVICE_CONVND_2D_FWD_INSTANCE_SOURCE})
|
||||
|
||||
set_target_properties(device_conv2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
set_target_properties(device_convnd_2d_fwd_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
|
||||
|
||||
clang_tidy_check(device_conv2d_fwd_instance)
|
||||
clang_tidy_check(device_convnd_2d_fwd_instance)
|
||||
|
||||
@@ -0,0 +1,113 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_conv2d_fwd_instance {
|
||||
|
||||
using BF16 = ck::bhalf_t;
|
||||
using F32 = float;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvFwd1x1P0 =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
|
||||
|
||||
static constexpr auto ConvFwd1x1S1P0 =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
|
||||
using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_bf16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(
|
||||
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances{});
|
||||
add_device_operation_instances(instances,
|
||||
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_bf16_instances{});
|
||||
add_device_operation_instances(instances,
|
||||
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_bf16_instances{});
|
||||
}
|
||||
|
||||
} // namespace device_conv2d_fwd_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,112 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_conv2d_fwd_instance {
|
||||
|
||||
using F16 = ck::half_t;
|
||||
using F32 = float;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvFwd1x1P0 =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
|
||||
|
||||
static constexpr auto ConvFwd1x1S1P0 =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
|
||||
using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 64, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 64, 128, 4, 8, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 64, 4, 8, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 64, 4, 8, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 64, 128, 4, 8, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 32, 4, 8, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 32, 128, 4, 8, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 32, 4, 8, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F16, F16, F16, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 32, 64, 4, 8, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(
|
||||
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances{});
|
||||
add_device_operation_instances(instances,
|
||||
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f16_instances{});
|
||||
add_device_operation_instances(instances,
|
||||
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{});
|
||||
}
|
||||
|
||||
} // namespace device_conv2d_fwd_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,111 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_conv2d_fwd_instance {
|
||||
|
||||
using F32 = float;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvFwd1x1P0 =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
|
||||
|
||||
static constexpr auto ConvFwd1x1S1P0 =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
|
||||
using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 256, 128, 4, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 256, 4, 4, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 128, 4, 4, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 128, 4, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 64, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 64, 128, 4, 4, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 64, 4, 4, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 64, 4, 4, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 64, 128, 4, 4, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 32, 4, 4, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 32, 128, 4, 4, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 32, 4, 4, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< F32, F32, F32, F32, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 32, 64, 4, 4, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 4, 4, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(
|
||||
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances, device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances{});
|
||||
add_device_operation_instances(instances,
|
||||
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_f32_instances{});
|
||||
add_device_operation_instances(instances,
|
||||
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{});
|
||||
}
|
||||
|
||||
} // namespace device_conv2d_fwd_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -0,0 +1,112 @@
|
||||
#include <stdlib.h>
|
||||
#include "config.hpp"
|
||||
#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "device_operation_instance.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
namespace device_conv2d_fwd_instance {
|
||||
|
||||
using F32 = float;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
using S = ck::Sequence<Is...>;
|
||||
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
static constexpr auto ConvFwd1x1P0 =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0;
|
||||
|
||||
static constexpr auto ConvFwd1x1S1P0 =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0;
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
|
||||
using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 2, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_int8_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1P0, 2, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
using device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances =
|
||||
std::tuple<
|
||||
// clang-format off
|
||||
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization| Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 256, 128, 4, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 256, 4, 16, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 128, 4, 16, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 128, 4, 16, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 64, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 64, 128, 4, 16, 32, 32, 2, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 64, 4, 16, 32, 32, 2, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 128, 64, 4, 16, 32, 32, 2, 1, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 256, 64, 128, 4, 16, 32, 32, 1, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 128, 32, 4, 16, 32, 32, 2, 1, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 128, 32, 128, 4, 16, 32, 32, 1, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 64, 32, 4, 16, 32, 32, 2, 1, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< int8_t, int8_t, int8_t, int32_t, PassThrough, PassThrough, PassThrough, ConvFwd1x1S1P0, 2, 64, 32, 64, 4, 16, 32, 32, 1, 2, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, S<4, 16, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, true, 7, 1>
|
||||
// clang-format on
|
||||
>;
|
||||
|
||||
void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(
|
||||
std::vector<DeviceConvFwdPtr<PassThrough, PassThrough, PassThrough>>& instances)
|
||||
{
|
||||
add_device_operation_instances(instances,
|
||||
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances{});
|
||||
add_device_operation_instances(instances,
|
||||
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_p0_int8_instances{});
|
||||
add_device_operation_instances(instances,
|
||||
device_conv2d_fwd_xdl_nhwc_kyxc_nhwk_1x1_s1_p0_int8_instances{});
|
||||
}
|
||||
|
||||
} // namespace device_conv2d_fwd_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
@@ -28,15 +28,12 @@ static constexpr auto ConvFwd1x1S1P0 =
|
||||
|
||||
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
|
||||
using device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk_bf16_instances = std::tuple<
|
||||
// clang-format off
|
||||
// clang-format off
|
||||
//################################################################| InData| WeiData| OutData| AccData| In| Wei| Out| ConvForward| NumDim| Block| MPer| NPer| K0Per| K1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CThreadTransfer| CThreadTransfer|
|
||||
//################################################################| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Specialization|Spatial| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
|
||||
//################################################################| | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
|
||||
//################################################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
|
||||
#if !CK_WORKAROUND_GITHUB_135
|
||||
// FIXME: this instance causes numerical errors.
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 256, 128, 4, 8, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
#endif
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 256, 4, 8, 32, 32, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 128, 128, 128, 4, 8, 32, 32, 4, 2, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K< BF16, BF16, BF16, F32, PassThrough, PassThrough, PassThrough, ConvFwdDefault, 3, 256, 128, 128, 4, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, 7, 1>,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#include <cstdlib>
|
||||
#include <functional>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
@@ -150,9 +151,12 @@ void profile_convnd_instances_impl(const ck::utils::conv::ConvParams& params,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::utils::FillUniform<int>,
|
||||
ck::utils::FillUniform<int>>>(
|
||||
params, true, ck::utils::FillUniform<int>{}, ck::utils::FillUniform<int>{});
|
||||
ck::utils::FillUniformDistributionIntegerValue<int>,
|
||||
ck::utils::FillUniformDistributionIntegerValue<int>>>(
|
||||
params,
|
||||
true,
|
||||
ck::utils::FillUniformDistributionIntegerValue<int>{},
|
||||
ck::utils::FillUniformDistributionIntegerValue<int>{});
|
||||
break;
|
||||
case 2:
|
||||
conv_instance = std::make_unique<
|
||||
@@ -165,12 +169,12 @@ void profile_convnd_instances_impl(const ck::utils::conv::ConvParams& params,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::utils::FillUniform<InDataType>,
|
||||
ck::utils::FillUniform<WeiDataType>>>(
|
||||
ck::utils::FillUniformDistribution<InDataType>,
|
||||
ck::utils::FillUniformDistribution<WeiDataType>>>(
|
||||
params,
|
||||
true,
|
||||
ck::utils::FillUniform<InDataType>{},
|
||||
ck::utils::FillUniform<WeiDataType>{});
|
||||
ck::utils::FillUniformDistribution<InDataType>{},
|
||||
ck::utils::FillUniformDistribution<WeiDataType>{});
|
||||
break;
|
||||
default: throw std::runtime_error("Unsupported init method!");
|
||||
}
|
||||
@@ -181,8 +185,10 @@ void profile_convnd_instances_impl(const ck::utils::conv::ConvParams& params,
|
||||
_1,
|
||||
_2,
|
||||
_3);
|
||||
OpInstanceRunEngine<InDataType, WeiDataType, OutDataType> run_engine(*conv_instance,
|
||||
reference_conv_fwd_fun);
|
||||
|
||||
OpInstanceRunEngine<InDataType, WeiDataType, OutDataType> run_engine(
|
||||
*conv_instance, reference_conv_fwd_fun, do_verification);
|
||||
|
||||
auto best_conf = run_engine.Profile(
|
||||
conv::ConvolutionFwdInstances<InDataType, WeiDataType, OutDataType>::template Get<NDim>(),
|
||||
time_kernel,
|
||||
|
||||
@@ -47,7 +47,7 @@ REPEAT=$9
|
||||
#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 256 64 1 1 56 56 1 1 1 1 0 0 0 0
|
||||
#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 1 1 56 56 1 1 1 1 0 0 0 0
|
||||
#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 64 3 3 56 56 1 1 1 1 1 1 1 1
|
||||
#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 8 7 7 224 224 2 2 1 1 3 3 3 3
|
||||
#$DRIVER $OP $DATATYPE $IN_LAYOUT $WEI_LAYOUT $OUT_LAYOUT $VERIFY $INIT $LOG $REPEAT $N 64 3 7 7 224 224 2 2 1 1 3 3 3 3
|
||||
|
||||
|
||||
# Resnet50 fusion
|
||||
|
||||
@@ -5,7 +5,7 @@ target_link_libraries(test_conv1d_fwd PRIVATE host_tensor device_conv1d_fwd_inst
|
||||
add_dependencies(test_convnd_fwd test_conv1d_fwd)
|
||||
|
||||
add_gtest_executable(test_conv2d_fwd conv2d_fwd.cpp)
|
||||
target_link_libraries(test_conv2d_fwd PRIVATE host_tensor device_conv2d_fwd_instance conv_util)
|
||||
target_link_libraries(test_conv2d_fwd PRIVATE host_tensor device_conv2d_fwd_instance device_convnd_2d_fwd_instance conv_util)
|
||||
add_dependencies(test_convnd_fwd test_conv2d_fwd)
|
||||
|
||||
add_gtest_executable(test_conv3d_fwd conv3d_fwd.cpp)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#include <iostream>
|
||||
#include <stdexcept>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
#include "gtest/gtest.h"
|
||||
@@ -11,83 +10,180 @@
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
bool test_conv1d_nwc_instances(const std::vector<test::conv::DeviceConvFwdNoOpPtr>& conv_ptrs)
|
||||
class Conv1dFwdNWCInstances : public ::testing::Test
|
||||
{
|
||||
public:
|
||||
template <typename T>
|
||||
bool test_conv1d_nwc_instances(const std::vector<test::conv::DeviceConvFwdNoOpPtr>& conv_ptrs,
|
||||
const ck::utils::conv::ConvParams& params)
|
||||
{
|
||||
using namespace std::placeholders;
|
||||
using namespace ck::utils;
|
||||
namespace ctl = ck::tensor_layout::convolution;
|
||||
|
||||
conv::ConvFwdOpInstance<T,
|
||||
T,
|
||||
T,
|
||||
ctl::NWC,
|
||||
ctl::KXC,
|
||||
ctl::NWK,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
FillUniformDistributionIntegerValue<T>,
|
||||
FillUniformDistributionIntegerValue<T>>
|
||||
conv_instance(params,
|
||||
true,
|
||||
FillUniformDistributionIntegerValue<T>{},
|
||||
FillUniformDistributionIntegerValue<T>{});
|
||||
auto reference_conv_fwd_fun =
|
||||
std::bind(conv::run_reference_convolution_forward<1, T, T, T>, params, _1, _2, _3);
|
||||
OpInstanceRunEngine<T, T, T> run_engine(conv_instance, reference_conv_fwd_fun);
|
||||
run_engine.SetAtol(atol_);
|
||||
run_engine.SetRtol(rtol_);
|
||||
return run_engine.Test(conv_ptrs);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool test_default()
|
||||
{
|
||||
return test_conv1d_nwc_instances<T>(
|
||||
ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<1>(), params_default_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool test_filter1x1_stride1_pad0()
|
||||
{
|
||||
return test_conv1d_nwc_instances<T>(
|
||||
ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<1>(),
|
||||
params_filter1x1_stride1_pad0_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool test_filter1x1_pad0()
|
||||
{
|
||||
return test_conv1d_nwc_instances<T>(
|
||||
ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<1>(),
|
||||
params_filter1x1_pad0_);
|
||||
}
|
||||
|
||||
static inline ck::utils::conv::ConvParams params_default_{
|
||||
1, 4, 256, 64, {3}, {71}, {2}, {2}, {2}, {2}};
|
||||
static inline ck::utils::conv::ConvParams params_filter1x1_stride1_pad0_{
|
||||
1, 4, 256, 64, {1}, {28}, {1}, {1}, {0}, {0}};
|
||||
static inline ck::utils::conv::ConvParams params_filter1x1_pad0_{
|
||||
1, 4, 256, 64, {1}, {28}, {2}, {1}, {0}, {0}};
|
||||
|
||||
private:
|
||||
double atol_{1e-5};
|
||||
double rtol_{1e-4};
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
TEST(Conv1DFwdNWC, IntegerValues)
|
||||
{
|
||||
using namespace std::placeholders;
|
||||
using namespace ck::utils;
|
||||
namespace ctl = ck::tensor_layout::convolution;
|
||||
using T = float;
|
||||
|
||||
ck::utils::conv::ConvParams params;
|
||||
params.num_dim_spatial_ = 1;
|
||||
params.filter_spatial_lengths_ = std::vector<ck::index_t>{3};
|
||||
params.input_spatial_lengths_ = std::vector<ck::index_t>{71};
|
||||
params.conv_filter_strides_ = std::vector<ck::index_t>{2};
|
||||
params.conv_filter_dilations_ = std::vector<ck::index_t>{1};
|
||||
params.input_left_pads_ = std::vector<ck::index_t>{1};
|
||||
params.input_right_pads_ = std::vector<ck::index_t>{1};
|
||||
ck::utils::conv::ConvParams params{1, 4, 256, 64, {3}, {36}, {1}, {2}, {2}, {2}};
|
||||
|
||||
conv::ConvFwdOpInstance<T, T, T, ctl::NWC, ctl::KCX, ctl::NWK> conv_instance(params);
|
||||
std::vector<test::conv::DeviceConvFwdNoOpPtr> conv_ptrs;
|
||||
test::conv::get_test_convolution_fwd_instance<1, T, T, T, T>(conv_ptrs);
|
||||
conv::ConvFwdOpInstance<T,
|
||||
T,
|
||||
T,
|
||||
ctl::NWC,
|
||||
ctl::KXC,
|
||||
ctl::NWK,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
FillUniformDistributionIntegerValue<T>,
|
||||
FillUniformDistributionIntegerValue<T>>
|
||||
conv_instance(params,
|
||||
true,
|
||||
FillUniformDistributionIntegerValue<T>{},
|
||||
FillUniformDistributionIntegerValue<T>{});
|
||||
|
||||
auto reference_conv_fwd_fun =
|
||||
std::bind(conv::run_reference_convolution_forward<1, T, T, T>, params, _1, _2, _3);
|
||||
OpInstanceRunEngine<T, T, T> run_engine(conv_instance, reference_conv_fwd_fun);
|
||||
return run_engine.Test(conv_ptrs);
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
TEST(Conv1DFwdNWC, TestConv1D)
|
||||
{
|
||||
using namespace std::placeholders;
|
||||
using namespace ck::utils;
|
||||
namespace ctl = ck::tensor_layout::convolution;
|
||||
|
||||
ck::utils::conv::ConvParams params;
|
||||
params.num_dim_spatial_ = 1;
|
||||
params.N_ = 2;
|
||||
params.K_ = 16;
|
||||
params.C_ = 4;
|
||||
params.filter_spatial_lengths_ = std::vector<ck::index_t>{3};
|
||||
params.input_spatial_lengths_ = std::vector<ck::index_t>{16};
|
||||
params.conv_filter_strides_ = std::vector<ck::index_t>{1};
|
||||
params.conv_filter_dilations_ = std::vector<ck::index_t>{1};
|
||||
params.input_left_pads_ = std::vector<ck::index_t>{1};
|
||||
params.input_right_pads_ = std::vector<ck::index_t>{1};
|
||||
|
||||
std::vector<test::conv::DeviceConvFwdNoOpPtr> conv_ptrs;
|
||||
test::conv::get_test_convolution_fwd_instance<1>(conv_ptrs);
|
||||
conv::ConvFwdOpInstance<float, float, float, ctl::NWC, ctl::KCX, ctl::NWK> conv_instance(
|
||||
params);
|
||||
|
||||
auto reference_conv_fwd_fun = std::bind(
|
||||
conv::run_reference_convolution_forward<1, float, float, float>, params, _1, _2, _3);
|
||||
OpInstanceRunEngine<float, float, float> run_engine(conv_instance, reference_conv_fwd_fun);
|
||||
run_engine.SetAtol(1e-5);
|
||||
run_engine.SetRtol(1e-4);
|
||||
EXPECT_TRUE(run_engine.Test(conv_ptrs));
|
||||
}
|
||||
|
||||
TEST(Conv1DFwdNWC, Bf16Iinstances)
|
||||
TEST(Conv1DFwdNWC, FloatingPointValues)
|
||||
{
|
||||
EXPECT_TRUE(test_conv1d_nwc_instances<ck::bhalf_t>(
|
||||
ck::utils::conv::ConvolutionFwdInstances<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t>::Get<1>()));
|
||||
using namespace std::placeholders;
|
||||
using namespace ck::utils;
|
||||
namespace ctl = ck::tensor_layout::convolution;
|
||||
using T = ck::half_t;
|
||||
|
||||
ck::utils::conv::ConvParams params{1, 4, 256, 64, {3}, {36}, {1}, {2}, {2}, {2}};
|
||||
|
||||
std::vector<test::conv::DeviceConvFwdNoOpPtr> conv_ptrs;
|
||||
test::conv::get_test_convolution_fwd_instance<1, T, T, T, float>(conv_ptrs);
|
||||
conv::ConvFwdOpInstance<T,
|
||||
T,
|
||||
T,
|
||||
ctl::NWC,
|
||||
ctl::KXC,
|
||||
ctl::NWK,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
FillUniformDistribution<T>,
|
||||
FillUniformDistribution<T>>
|
||||
conv_instance(params, true, FillUniformDistribution<T>{}, FillUniformDistribution<T>{});
|
||||
|
||||
auto reference_conv_fwd_fun =
|
||||
std::bind(conv::run_reference_convolution_forward<1, T, T, T>, params, _1, _2, _3);
|
||||
OpInstanceRunEngine<T, T, T> run_engine(conv_instance, reference_conv_fwd_fun);
|
||||
run_engine.SetAtol(0.1);
|
||||
run_engine.SetRtol(1e-2);
|
||||
EXPECT_TRUE(run_engine.Test(conv_ptrs));
|
||||
}
|
||||
|
||||
TEST(Conv1DFwdNWC, F16Instances)
|
||||
TEST_F(Conv1dFwdNWCInstances, BF16_default) { EXPECT_TRUE(this->test_default<ck::bhalf_t>()); }
|
||||
TEST_F(Conv1dFwdNWCInstances, BF16_filter1x1_stride1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(test_conv1d_nwc_instances<ck::half_t>(
|
||||
ck::utils::conv::ConvolutionFwdInstances<ck::half_t, ck::half_t, ck::half_t>::Get<1>()));
|
||||
EXPECT_TRUE(this->test_filter1x1_stride1_pad0<ck::bhalf_t>());
|
||||
}
|
||||
TEST_F(Conv1dFwdNWCInstances, BF16_filter1x1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(this->test_filter1x1_pad0<ck::bhalf_t>());
|
||||
}
|
||||
|
||||
TEST(Conv1DFwdNWC, F32Instances)
|
||||
TEST_F(Conv1dFwdNWCInstances, F16_default) { EXPECT_TRUE(this->test_default<ck::half_t>()); }
|
||||
TEST_F(Conv1dFwdNWCInstances, F16_filter1x1_stride1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(test_conv1d_nwc_instances<float>(
|
||||
ck::utils::conv::ConvolutionFwdInstances<float, float, float>::Get<1>()));
|
||||
EXPECT_TRUE(this->test_filter1x1_stride1_pad0<ck::half_t>());
|
||||
}
|
||||
TEST_F(Conv1dFwdNWCInstances, F16_filter1x1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(this->test_filter1x1_pad0<ck::half_t>());
|
||||
}
|
||||
|
||||
TEST(Conv1DFwdNWC, Int8Instances)
|
||||
TEST_F(Conv1dFwdNWCInstances, F32_default) { EXPECT_TRUE(this->test_default<float>()); }
|
||||
TEST_F(Conv1dFwdNWCInstances, F32_filter1x1_stride1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(test_conv1d_nwc_instances<int8_t>(
|
||||
ck::utils::conv::ConvolutionFwdInstances<int8_t, int8_t, int8_t>::Get<1>()));
|
||||
EXPECT_TRUE(this->test_filter1x1_stride1_pad0<float>());
|
||||
}
|
||||
TEST_F(Conv1dFwdNWCInstances, F32_filter1x1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(this->test_filter1x1_pad0<float>());
|
||||
}
|
||||
|
||||
TEST_F(Conv1dFwdNWCInstances, I8_default) { EXPECT_TRUE(this->test_default<int8_t>()); }
|
||||
TEST_F(Conv1dFwdNWCInstances, I8_filter1x1_stride1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(this->test_filter1x1_stride1_pad0<int8_t>());
|
||||
}
|
||||
TEST_F(Conv1dFwdNWCInstances, I8_filter1x1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(this->test_filter1x1_pad0<int8_t>());
|
||||
}
|
||||
|
||||
@@ -1,91 +1,265 @@
|
||||
#include <half.hpp>
|
||||
#include <iostream>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
#include "gtest/gtest.h"
|
||||
|
||||
#include "ck/library/utility/conv_util.hpp"
|
||||
#include "config.hpp"
|
||||
#include "conv_util.hpp"
|
||||
#include "data_type.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "ck/library/utility/conv_util.hpp"
|
||||
#include "conv_util.hpp"
|
||||
#include "fill.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
bool test_conv2d_nhwc_instances(const std::vector<test::conv::DeviceConvFwdNoOpPtr>& conv_ptrs)
|
||||
class Conv2dFwdNHWCInstances : public ::testing::Test
|
||||
{
|
||||
public:
|
||||
template <typename T>
|
||||
bool test_conv2d_nhwc_instances(const std::vector<test::conv::DeviceConvFwdNoOpPtr>& conv_ptrs,
|
||||
const ck::utils::conv::ConvParams& params)
|
||||
{
|
||||
using namespace std::placeholders;
|
||||
using namespace ck::utils;
|
||||
|
||||
conv::ConvFwdOpInstance<T,
|
||||
T,
|
||||
T,
|
||||
ck::tensor_layout::convolution::NHWC,
|
||||
ck::tensor_layout::convolution::KYXC,
|
||||
ck::tensor_layout::convolution::NHWK,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
FillUniformDistributionIntegerValue<T>,
|
||||
FillUniformDistributionIntegerValue<T>>
|
||||
conv_instance(params,
|
||||
true,
|
||||
FillUniformDistributionIntegerValue<T>{},
|
||||
FillUniformDistributionIntegerValue<T>{});
|
||||
auto reference_conv_fwd_fun =
|
||||
std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3);
|
||||
OpInstanceRunEngine<T, T, T> run_engine(conv_instance, reference_conv_fwd_fun);
|
||||
run_engine.SetAtol(atol_);
|
||||
run_engine.SetRtol(rtol_);
|
||||
return run_engine.Test(conv_ptrs);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool test_default(bool use_convnd = false)
|
||||
{
|
||||
if(use_convnd)
|
||||
{
|
||||
return test_conv2d_nhwc_instances<T>(
|
||||
test::conv::ConvolutionNDFwdInstances<T, T, T>::Get(2), params_default_);
|
||||
}
|
||||
else
|
||||
{
|
||||
return test_conv2d_nhwc_instances<T>(
|
||||
ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<2>(),
|
||||
params_default_);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool test_filter1x1_stride1_pad0(bool use_convnd = false)
|
||||
{
|
||||
if(use_convnd)
|
||||
{
|
||||
return test_conv2d_nhwc_instances<T>(
|
||||
test::conv::ConvolutionNDFwdInstances<T, T, T>::Get(2),
|
||||
params_filter1x1_stride1_pad0_);
|
||||
}
|
||||
else
|
||||
{
|
||||
return test_conv2d_nhwc_instances<T>(
|
||||
ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<2>(),
|
||||
params_filter1x1_stride1_pad0_);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool test_filter1x1_pad0(bool use_convnd = false)
|
||||
{
|
||||
if(use_convnd)
|
||||
{
|
||||
return test_conv2d_nhwc_instances<T>(
|
||||
test::conv::ConvolutionNDFwdInstances<T, T, T>::Get(2), params_filter1x1_pad0_);
|
||||
}
|
||||
else
|
||||
{
|
||||
return test_conv2d_nhwc_instances<T>(
|
||||
ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<2>(),
|
||||
params_filter1x1_pad0_);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool test_oddC()
|
||||
{
|
||||
return test_conv2d_nhwc_instances<T>(
|
||||
ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<2>(), params_oddC_);
|
||||
}
|
||||
|
||||
static inline ck::utils::conv::ConvParams params_default_{
|
||||
2, 4, 256, 64, {3, 3}, {36, 36}, {2, 2}, {2, 2}, {2, 2}, {2, 2}};
|
||||
static inline ck::utils::conv::ConvParams params_filter1x1_stride1_pad0_{
|
||||
2, 4, 256, 64, {1, 1}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
|
||||
static inline ck::utils::conv::ConvParams params_filter1x1_pad0_{
|
||||
2, 4, 256, 64, {1, 1}, {28, 28}, {2, 2}, {1, 1}, {0, 0}, {0, 0}};
|
||||
static inline ck::utils::conv::ConvParams params_oddC_{
|
||||
2, 4, 256, 3, {3, 3}, {28, 28}, {1, 1}, {1, 1}, {0, 0}, {0, 0}};
|
||||
|
||||
private:
|
||||
double atol_{1e-5};
|
||||
double rtol_{1e-4};
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
TEST(Conv2DFwdNHWC, IntegerValues)
|
||||
{
|
||||
using namespace std::placeholders;
|
||||
using namespace ck::utils;
|
||||
using T = float;
|
||||
|
||||
conv::ConvParams params;
|
||||
params.num_dim_spatial_ = 2;
|
||||
params.filter_spatial_lengths_ = std::vector<ck::index_t>{3, 3};
|
||||
params.input_spatial_lengths_ = std::vector<ck::index_t>{71, 71};
|
||||
params.conv_filter_strides_ = std::vector<ck::index_t>{2, 2};
|
||||
params.conv_filter_dilations_ = std::vector<ck::index_t>{1, 1};
|
||||
params.input_left_pads_ = std::vector<ck::index_t>{1, 1};
|
||||
params.input_right_pads_ = std::vector<ck::index_t>{1, 1};
|
||||
ck::utils::conv::ConvParams params{
|
||||
2, 4, 256, 64, {3, 3}, {36, 36}, {1, 1}, {2, 2}, {2, 2}, {2, 2}};
|
||||
|
||||
conv::ConvFwdOpInstance<T, T, T> conv_instance(params);
|
||||
std::vector<test::conv::DeviceConvFwdNoOpPtr> conv_ptrs;
|
||||
test::conv::get_test_convolution_fwd_instance<2, T, T, T, T>(conv_ptrs);
|
||||
conv::ConvFwdOpInstance<T,
|
||||
T,
|
||||
T,
|
||||
ck::tensor_layout::convolution::NHWC,
|
||||
ck::tensor_layout::convolution::KYXC,
|
||||
ck::tensor_layout::convolution::NHWK,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
FillUniformDistributionIntegerValue<T>,
|
||||
FillUniformDistributionIntegerValue<T>>
|
||||
conv_instance(params,
|
||||
true,
|
||||
FillUniformDistributionIntegerValue<T>{},
|
||||
FillUniformDistributionIntegerValue<T>{});
|
||||
|
||||
auto reference_conv_fwd_fun =
|
||||
std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3);
|
||||
OpInstanceRunEngine<T, T, T> run_engine(conv_instance, reference_conv_fwd_fun);
|
||||
return run_engine.Test(conv_ptrs);
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
TEST(Conv2DFwdNHWC, TestConv2D)
|
||||
{
|
||||
using namespace std::placeholders;
|
||||
using namespace ck::utils;
|
||||
|
||||
ck::utils::conv::ConvParams params;
|
||||
params.N_ = 2;
|
||||
params.K_ = 16;
|
||||
params.C_ = 4;
|
||||
params.input_spatial_lengths_ = std::vector<ck::index_t>{16, 16};
|
||||
params.conv_filter_strides_ = std::vector<ck::index_t>{1, 1};
|
||||
|
||||
std::vector<test::conv::DeviceConvFwdNoOpPtr> conv_ptrs;
|
||||
test::conv::get_test_convolution_fwd_instance<2>(conv_ptrs);
|
||||
conv::ConvFwdOpInstance<float, float, float> conv_instance(params);
|
||||
|
||||
auto reference_conv_fwd_fun = std::bind(
|
||||
conv::run_reference_convolution_forward<2, float, float, float>, params, _1, _2, _3);
|
||||
OpInstanceRunEngine<float, float, float> run_engine(conv_instance, reference_conv_fwd_fun);
|
||||
run_engine.SetAtol(1e-5);
|
||||
run_engine.SetRtol(1e-4);
|
||||
EXPECT_TRUE(run_engine.Test(conv_ptrs));
|
||||
}
|
||||
|
||||
TEST(Conv2DFwdNHWC, Bf16Instances)
|
||||
TEST(Conv2DFwdNHWC, FloatingPointValues)
|
||||
{
|
||||
EXPECT_TRUE(test_conv2d_nhwc_instances<ck::bhalf_t>(
|
||||
ck::utils::conv::ConvolutionFwdInstances<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t>::Get<2>()));
|
||||
using namespace std::placeholders;
|
||||
using namespace ck::utils;
|
||||
using T = ck::half_t;
|
||||
|
||||
ck::utils::conv::ConvParams params{
|
||||
2, 4, 256, 64, {3, 3}, {36, 36}, {2, 2}, {2, 2}, {2, 2}, {2, 2}};
|
||||
|
||||
std::vector<test::conv::DeviceConvFwdNoOpPtr> conv_ptrs;
|
||||
test::conv::get_test_convolution_fwd_instance<2, T, T, T, float>(conv_ptrs);
|
||||
conv::ConvFwdOpInstance<T,
|
||||
T,
|
||||
T,
|
||||
ck::tensor_layout::convolution::NHWC,
|
||||
ck::tensor_layout::convolution::KYXC,
|
||||
ck::tensor_layout::convolution::NHWK,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
FillUniformDistribution<T>,
|
||||
FillUniformDistribution<T>>
|
||||
conv_instance(params, true, FillUniformDistribution<T>{}, FillUniformDistribution<T>{});
|
||||
|
||||
auto reference_conv_fwd_fun =
|
||||
std::bind(conv::run_reference_convolution_forward<2, T, T, T>, params, _1, _2, _3);
|
||||
OpInstanceRunEngine<T, T, T> run_engine(conv_instance, reference_conv_fwd_fun);
|
||||
run_engine.SetAtol(2e-4);
|
||||
run_engine.SetRtol(1e-3);
|
||||
EXPECT_TRUE(run_engine.Test(conv_ptrs));
|
||||
}
|
||||
|
||||
TEST(Conv2DFwdNHWC, F16Instances)
|
||||
TEST_F(Conv2dFwdNHWCInstances, BF16_default) { EXPECT_TRUE(this->test_default<ck::bhalf_t>()); }
|
||||
TEST_F(Conv2dFwdNHWCInstances, BF16_filter1x1_stride1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(test_conv2d_nhwc_instances<ck::half_t>(
|
||||
ck::utils::conv::ConvolutionFwdInstances<ck::half_t, ck::half_t, ck::half_t>::Get<2>()));
|
||||
EXPECT_TRUE(this->test_filter1x1_stride1_pad0<ck::bhalf_t>());
|
||||
}
|
||||
TEST_F(Conv2dFwdNHWCInstances, BF16_filter1x1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(this->test_filter1x1_pad0<ck::bhalf_t>());
|
||||
}
|
||||
TEST_F(Conv2dFwdNHWCInstances, F16_default) { EXPECT_TRUE(this->test_default<ck::half_t>()); }
|
||||
TEST_F(Conv2dFwdNHWCInstances, F16_filter1x1_stride1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(this->test_filter1x1_stride1_pad0<ck::half_t>());
|
||||
}
|
||||
TEST_F(Conv2dFwdNHWCInstances, F16_filter1x1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(this->test_filter1x1_pad0<ck::half_t>());
|
||||
}
|
||||
TEST_F(Conv2dFwdNHWCInstances, F16_oddC) { EXPECT_TRUE(this->test_oddC<ck::half_t>()); }
|
||||
TEST_F(Conv2dFwdNHWCInstances, F32_default) { EXPECT_TRUE(this->test_default<float>()); }
|
||||
TEST_F(Conv2dFwdNHWCInstances, F32_filter1x1_stride1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(this->test_filter1x1_stride1_pad0<float>());
|
||||
}
|
||||
TEST_F(Conv2dFwdNHWCInstances, F32_filter1x1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(this->test_filter1x1_pad0<float>());
|
||||
}
|
||||
TEST_F(Conv2dFwdNHWCInstances, I8_default) { EXPECT_TRUE(this->test_default<int8_t>()); }
|
||||
TEST_F(Conv2dFwdNHWCInstances, I8_filter1x1_stride1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(this->test_filter1x1_stride1_pad0<int8_t>());
|
||||
}
|
||||
TEST_F(Conv2dFwdNHWCInstances, I8_filter1x1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(this->test_filter1x1_pad0<int8_t>());
|
||||
}
|
||||
|
||||
TEST(Conv2DFwdNHWC, BF32Instances)
|
||||
TEST_F(Conv2dFwdNHWCInstances, ND_BF16_default)
|
||||
{
|
||||
EXPECT_TRUE(test_conv2d_nhwc_instances<float>(
|
||||
ck::utils::conv::ConvolutionFwdInstances<float, float, float>::Get<2>()));
|
||||
EXPECT_TRUE(this->test_default<ck::bhalf_t>(true));
|
||||
}
|
||||
|
||||
TEST(Conv2DFwdNHWC, F32Instances)
|
||||
TEST_F(Conv2dFwdNHWCInstances, ND_BF16_filter1x1_stride1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(test_conv2d_nhwc_instances<float>(
|
||||
ck::utils::conv::ConvolutionFwdInstances<float, float, float>::Get<2>()));
|
||||
EXPECT_TRUE(this->test_filter1x1_stride1_pad0<ck::bhalf_t>(true));
|
||||
}
|
||||
|
||||
TEST(Conv2DFwdNHWC, Int8Instances)
|
||||
TEST_F(Conv2dFwdNHWCInstances, ND_BF16_filter1x1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(test_conv2d_nhwc_instances<int8_t>(
|
||||
ck::utils::conv::ConvolutionFwdInstances<int8_t, int8_t, int8_t>::Get<2>()));
|
||||
EXPECT_TRUE(this->test_filter1x1_pad0<ck::bhalf_t>(true));
|
||||
}
|
||||
TEST_F(Conv2dFwdNHWCInstances, ND_F16_default)
|
||||
{
|
||||
EXPECT_TRUE(this->test_default<ck::half_t>(true));
|
||||
}
|
||||
TEST_F(Conv2dFwdNHWCInstances, ND_F16_filter1x1_stride1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(this->test_filter1x1_stride1_pad0<ck::half_t>(true));
|
||||
}
|
||||
TEST_F(Conv2dFwdNHWCInstances, ND_F16_filter1x1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(this->test_filter1x1_pad0<ck::half_t>(true));
|
||||
}
|
||||
TEST_F(Conv2dFwdNHWCInstances, ND_F32_default) { EXPECT_TRUE(this->test_default<float>(true)); }
|
||||
TEST_F(Conv2dFwdNHWCInstances, ND_F32_filter1x1_stride1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(this->test_filter1x1_stride1_pad0<float>(true));
|
||||
}
|
||||
TEST_F(Conv2dFwdNHWCInstances, ND_F32_filter1x1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(this->test_filter1x1_pad0<float>(true));
|
||||
}
|
||||
TEST_F(Conv2dFwdNHWCInstances, ND_I8_default) { EXPECT_TRUE(this->test_default<int8_t>(true)); }
|
||||
TEST_F(Conv2dFwdNHWCInstances, ND_I8_filter1x1_stride1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(this->test_filter1x1_stride1_pad0<int8_t>(true));
|
||||
}
|
||||
TEST_F(Conv2dFwdNHWCInstances, ND_I8_filter1x1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(this->test_filter1x1_pad0<int8_t>(true));
|
||||
}
|
||||
|
||||
@@ -12,61 +12,143 @@
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename T>
|
||||
bool test_conv3d_ndhwc_instances(const std::vector<test::conv::DeviceConvFwdNoOpPtr>& conv_ptrs)
|
||||
class Conv3dFwdNDHWCInstances : public ::testing::Test
|
||||
{
|
||||
public:
|
||||
template <typename T>
|
||||
bool test_conv3d_nwc_instances(const std::vector<test::conv::DeviceConvFwdNoOpPtr>& conv_ptrs,
|
||||
const ck::utils::conv::ConvParams& params)
|
||||
{
|
||||
using namespace std::placeholders;
|
||||
using namespace ck::utils;
|
||||
namespace ctl = ck::tensor_layout::convolution;
|
||||
|
||||
conv::ConvFwdOpInstance<T,
|
||||
T,
|
||||
T,
|
||||
ctl::NDHWC,
|
||||
ctl::KZYXC,
|
||||
ctl::NDHWK,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
FillUniformDistributionIntegerValue<T>,
|
||||
FillUniformDistributionIntegerValue<T>>
|
||||
conv_instance(params,
|
||||
true,
|
||||
FillUniformDistributionIntegerValue<T>{},
|
||||
FillUniformDistributionIntegerValue<T>{});
|
||||
auto reference_conv_fwd_fun =
|
||||
std::bind(conv::run_reference_convolution_forward<3, T, T, T>, params, _1, _2, _3);
|
||||
OpInstanceRunEngine<T, T, T> run_engine(conv_instance, reference_conv_fwd_fun);
|
||||
run_engine.SetAtol(atol_);
|
||||
run_engine.SetRtol(rtol_);
|
||||
return run_engine.Test(conv_ptrs);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool test_default()
|
||||
{
|
||||
return test_conv3d_nwc_instances<T>(
|
||||
ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<3>(), params_default_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool test_filter1x1_stride1_pad0()
|
||||
{
|
||||
return test_conv3d_nwc_instances<T>(
|
||||
ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<3>(),
|
||||
params_filter1x1_stride1_pad0_);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool test_filter1x1_pad0()
|
||||
{
|
||||
return test_conv3d_nwc_instances<T>(
|
||||
ck::utils::conv::ConvolutionFwdInstances<T, T, T>::template Get<3>(),
|
||||
params_filter1x1_pad0_);
|
||||
}
|
||||
|
||||
static inline ck::utils::conv::ConvParams params_default_{
|
||||
3, 4, 256, 64, {3, 3, 3}, {28, 28, 28}, {2, 2, 2}, {2, 2, 2}, {2, 2, 2}, {2, 2, 2}};
|
||||
static inline ck::utils::conv::ConvParams params_filter1x1_stride1_pad0_{
|
||||
3, 4, 256, 64, {1, 1, 1}, {28, 28, 28}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
|
||||
static inline ck::utils::conv::ConvParams params_filter1x1_pad0_{
|
||||
3, 4, 256, 64, {1, 1, 1}, {28, 28, 28}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}};
|
||||
|
||||
private:
|
||||
double atol_{1e-5};
|
||||
double rtol_{1e-4};
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
TEST(Conv3DFwdNDHWC, IntegerValues)
|
||||
{
|
||||
using namespace std::placeholders;
|
||||
using namespace ck::utils;
|
||||
namespace ctl = ck::tensor_layout::convolution;
|
||||
using T = float;
|
||||
|
||||
conv::ConvParams params;
|
||||
params.N_ = 64;
|
||||
params.num_dim_spatial_ = 3;
|
||||
params.filter_spatial_lengths_ = std::vector<ck::index_t>{3, 3, 2};
|
||||
params.input_spatial_lengths_ = std::vector<ck::index_t>{32, 32, 2};
|
||||
params.conv_filter_strides_ = std::vector<ck::index_t>{2, 2, 2};
|
||||
params.conv_filter_dilations_ = std::vector<ck::index_t>{1, 1, 1};
|
||||
params.input_left_pads_ = std::vector<ck::index_t>{1, 1, 1};
|
||||
params.input_right_pads_ = std::vector<ck::index_t>{1, 1, 1};
|
||||
ck::utils::conv::ConvParams params{
|
||||
3, 4, 256, 64, {3, 3, 3}, {18, 18, 18}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, {2, 2, 2}};
|
||||
|
||||
conv::ConvFwdOpInstance<T, T, T, ctl::NDHWC, ctl::KZYXC, ctl::NDHWK> conv_instance(params);
|
||||
std::vector<test::conv::DeviceConvFwdNoOpPtr> conv_ptrs;
|
||||
test::conv::get_test_convolution_fwd_instance<3, T, T, T, T>(conv_ptrs);
|
||||
conv::ConvFwdOpInstance<T,
|
||||
T,
|
||||
T,
|
||||
ctl::NDHWC,
|
||||
ctl::KZYXC,
|
||||
ctl::NDHWK,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
FillUniformDistributionIntegerValue<T>,
|
||||
FillUniformDistributionIntegerValue<T>>
|
||||
conv_instance(params,
|
||||
true,
|
||||
FillUniformDistributionIntegerValue<T>{},
|
||||
FillUniformDistributionIntegerValue<T>{});
|
||||
|
||||
auto reference_conv_fwd_fun =
|
||||
std::bind(conv::run_reference_convolution_forward<3, T, T, T>, params, _1, _2, _3);
|
||||
OpInstanceRunEngine<T, T, T> run_engine(conv_instance, reference_conv_fwd_fun);
|
||||
return run_engine.Test(conv_ptrs);
|
||||
run_engine.SetAtol(1e-5);
|
||||
run_engine.SetRtol(1e-3);
|
||||
EXPECT_TRUE(run_engine.Test(conv_ptrs));
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
TEST(Conv3DFwdNDHWC, TestConv3D)
|
||||
TEST(Conv3DFwdNDHWC, FloatingPointValues)
|
||||
{
|
||||
using namespace std::placeholders;
|
||||
using namespace ck::utils;
|
||||
namespace ctl = ck::tensor_layout::convolution;
|
||||
using T = ck::half_t;
|
||||
|
||||
conv::ConvParams params;
|
||||
params.num_dim_spatial_ = 3;
|
||||
params.N_ = 2;
|
||||
params.K_ = 16;
|
||||
params.C_ = 4;
|
||||
params.filter_spatial_lengths_ = std::vector<ck::index_t>{3, 3, 3};
|
||||
params.input_spatial_lengths_ = std::vector<ck::index_t>{16, 16, 16};
|
||||
params.conv_filter_strides_ = std::vector<ck::index_t>{1, 1, 1};
|
||||
params.conv_filter_dilations_ = std::vector<ck::index_t>{1, 1, 1};
|
||||
params.input_left_pads_ = std::vector<ck::index_t>{1, 1, 1};
|
||||
params.input_right_pads_ = std::vector<ck::index_t>{1, 1, 1};
|
||||
ck::utils::conv::ConvParams params{
|
||||
3, 4, 256, 64, {3, 3, 3}, {18, 18, 18}, {1, 1, 1}, {2, 2, 2}, {2, 2, 2}, {2, 2, 2}};
|
||||
|
||||
std::vector<test::conv::DeviceConvFwdNoOpPtr> conv_ptrs;
|
||||
test::conv::get_test_convolution_fwd_instance<3>(conv_ptrs);
|
||||
conv::ConvFwdOpInstance<float, float, float, ctl::NDHWC, ctl::KZYXC, ctl::NDHWK> conv_instance(
|
||||
params);
|
||||
test::conv::get_test_convolution_fwd_instance<3, T, T, T, float>(conv_ptrs);
|
||||
conv::ConvFwdOpInstance<T,
|
||||
T,
|
||||
T,
|
||||
ctl::NDHWC,
|
||||
ctl::KZYXC,
|
||||
ctl::NDHWK,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
ck::tensor_operation::element_wise::PassThrough,
|
||||
FillUniformDistribution<T>,
|
||||
FillUniformDistribution<T>>
|
||||
conv_instance(params, true, FillUniformDistribution<T>{}, FillUniformDistribution<T>{});
|
||||
|
||||
auto reference_conv_fwd_fun = std::bind(
|
||||
conv::run_reference_convolution_forward<3, float, float, float>, params, _1, _2, _3);
|
||||
OpInstanceRunEngine<float, float, float> run_engine(conv_instance, reference_conv_fwd_fun);
|
||||
run_engine.SetAtol(1e-5);
|
||||
run_engine.SetRtol(1e-4);
|
||||
auto reference_conv_fwd_fun =
|
||||
std::bind(conv::run_reference_convolution_forward<3, T, T, T>, params, _1, _2, _3);
|
||||
OpInstanceRunEngine<T, T, T> run_engine(conv_instance, reference_conv_fwd_fun);
|
||||
run_engine.SetAtol(1e-3);
|
||||
run_engine.SetRtol(1e-3);
|
||||
EXPECT_TRUE(run_engine.Test(conv_ptrs));
|
||||
}
|
||||
|
||||
@@ -74,6 +156,7 @@ TEST(Conv3DFwdNDHWC, InputOver2GB)
|
||||
{
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using namespace ck::utils;
|
||||
using T = float;
|
||||
|
||||
// >2GB Input
|
||||
conv::ConvParams params;
|
||||
@@ -89,8 +172,7 @@ TEST(Conv3DFwdNDHWC, InputOver2GB)
|
||||
params.input_right_pads_ = std::vector<ck::index_t>{1, 1, 1};
|
||||
|
||||
std::vector<test::conv::DeviceConvFwdNoOpPtr> conv_ptrs;
|
||||
test::conv::get_test_convolution_fwd_instance<3>(conv_ptrs);
|
||||
|
||||
test::conv::get_test_convolution_fwd_instance<3, T, T, T, T>(conv_ptrs);
|
||||
auto arg = conv_ptrs.back()->MakeArgumentPointer(nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
@@ -114,6 +196,7 @@ TEST(Conv3DFwdNDHWC, FiltersOver2GB)
|
||||
{
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using namespace ck::utils;
|
||||
using T = float;
|
||||
|
||||
// >2GB Filters
|
||||
conv::ConvParams params;
|
||||
@@ -129,8 +212,7 @@ TEST(Conv3DFwdNDHWC, FiltersOver2GB)
|
||||
params.input_right_pads_ = std::vector<ck::index_t>{1, 1, 1};
|
||||
|
||||
std::vector<test::conv::DeviceConvFwdNoOpPtr> conv_ptrs;
|
||||
test::conv::get_test_convolution_fwd_instance<3>(conv_ptrs);
|
||||
|
||||
test::conv::get_test_convolution_fwd_instance<3, T, T, T, T>(conv_ptrs);
|
||||
auto arg = conv_ptrs.back()->MakeArgumentPointer(nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
@@ -154,6 +236,7 @@ TEST(Conv3DFwdNDHWC, OutputOver2GB)
|
||||
{
|
||||
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
|
||||
using namespace ck::utils;
|
||||
using T = float;
|
||||
|
||||
// >2GB Output
|
||||
conv::ConvParams params;
|
||||
@@ -169,7 +252,7 @@ TEST(Conv3DFwdNDHWC, OutputOver2GB)
|
||||
params.input_right_pads_ = std::vector<ck::index_t>{2, 2, 2};
|
||||
|
||||
std::vector<test::conv::DeviceConvFwdNoOpPtr> conv_ptrs;
|
||||
test::conv::get_test_convolution_fwd_instance<3>(conv_ptrs);
|
||||
test::conv::get_test_convolution_fwd_instance<3, T, T, T, T>(conv_ptrs);
|
||||
auto arg = conv_ptrs.back()->MakeArgumentPointer(nullptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
@@ -189,26 +272,42 @@ TEST(Conv3DFwdNDHWC, OutputOver2GB)
|
||||
EXPECT_FALSE(conv_ptrs.back()->IsSupportedArgument(arg.get()));
|
||||
}
|
||||
|
||||
TEST(Conv3DFwdNDHWC, Bf16Instances)
|
||||
TEST_F(Conv3dFwdNDHWCInstances, BF16_default) { EXPECT_TRUE(this->test_default<ck::bhalf_t>()); }
|
||||
TEST_F(Conv3dFwdNDHWCInstances, BF16_filter1x1_stride1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(test_conv3d_ndhwc_instances<ck::bhalf_t>(
|
||||
ck::utils::conv::ConvolutionFwdInstances<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t>::Get<3>()));
|
||||
EXPECT_TRUE(this->test_filter1x1_stride1_pad0<ck::bhalf_t>());
|
||||
}
|
||||
TEST_F(Conv3dFwdNDHWCInstances, BF16_filter1x1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(this->test_filter1x1_pad0<ck::bhalf_t>());
|
||||
}
|
||||
|
||||
TEST(Conv3DFwdNDHWC, F16Instances)
|
||||
TEST_F(Conv3dFwdNDHWCInstances, F16_default) { EXPECT_TRUE(this->test_default<ck::half_t>()); }
|
||||
TEST_F(Conv3dFwdNDHWCInstances, F16_filter1x1_stride1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(test_conv3d_ndhwc_instances<ck::half_t>(
|
||||
ck::utils::conv::ConvolutionFwdInstances<ck::half_t, ck::half_t, ck::half_t>::Get<3>()));
|
||||
EXPECT_TRUE(this->test_filter1x1_stride1_pad0<ck::half_t>());
|
||||
}
|
||||
TEST_F(Conv3dFwdNDHWCInstances, F16_filter1x1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(this->test_filter1x1_pad0<ck::half_t>());
|
||||
}
|
||||
|
||||
TEST(Conv3DFwdNDHWC, F32Instances)
|
||||
TEST_F(Conv3dFwdNDHWCInstances, F32_default) { EXPECT_TRUE(this->test_default<float>()); }
|
||||
TEST_F(Conv3dFwdNDHWCInstances, F32_filter1x1_stride1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(test_conv3d_ndhwc_instances<float>(
|
||||
ck::utils::conv::ConvolutionFwdInstances<float, float, float>::Get<3>()));
|
||||
EXPECT_TRUE(this->test_filter1x1_stride1_pad0<float>());
|
||||
}
|
||||
TEST_F(Conv3dFwdNDHWCInstances, F32_filter1x1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(this->test_filter1x1_pad0<float>());
|
||||
}
|
||||
|
||||
TEST(Conv3DFwdNDHWC, Int8Instances)
|
||||
TEST_F(Conv3dFwdNDHWCInstances, I8_default) { EXPECT_TRUE(this->test_default<int8_t>()); }
|
||||
TEST_F(Conv3dFwdNDHWCInstances, I8_filter1x1_stride1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(test_conv3d_ndhwc_instances<int8_t>(
|
||||
ck::utils::conv::ConvolutionFwdInstances<int8_t, int8_t, int8_t>::Get<3>()));
|
||||
EXPECT_TRUE(this->test_filter1x1_stride1_pad0<int8_t>());
|
||||
}
|
||||
TEST_F(Conv3dFwdNDHWCInstances, I8_filter1x1_pad0)
|
||||
{
|
||||
EXPECT_TRUE(this->test_filter1x1_pad0<int8_t>());
|
||||
}
|
||||
|
||||
@@ -1,14 +1,33 @@
|
||||
#ifndef TEST_CONV_UTIL_HPP
|
||||
#define TEST_CONV_UTIL_HPP
|
||||
#pragma once
|
||||
|
||||
#include <tuple>
|
||||
|
||||
#include "config.hpp"
|
||||
#include "data_type.hpp"
|
||||
#include "device_convnd_fwd_xdl_nhwc_kyxc_nhwk.hpp"
|
||||
#include "element_wise_operation.hpp"
|
||||
#include "host_tensor.hpp"
|
||||
#include "sequence.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace tensor_operation {
|
||||
namespace device {
|
||||
|
||||
using DeviceConvFwdNoOpPtr = DeviceConvFwdPtr<element_wise::PassThrough,
|
||||
element_wise::PassThrough,
|
||||
element_wise::PassThrough>;
|
||||
namespace device_conv2d_fwd_instance {
|
||||
|
||||
void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
|
||||
void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(std::vector<DeviceConvFwdNoOpPtr>&);
|
||||
void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(std::vector<DeviceConvFwdNoOpPtr>&);
|
||||
void add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(std::vector<DeviceConvFwdNoOpPtr>&);
|
||||
|
||||
} // namespace device_conv2d_fwd_instance
|
||||
} // namespace device
|
||||
} // namespace tensor_operation
|
||||
} // namespace ck
|
||||
|
||||
namespace test {
|
||||
namespace conv {
|
||||
|
||||
@@ -25,57 +44,128 @@ using DeviceConvFwdNoOpPtr =
|
||||
static constexpr auto ConvFwdDefault =
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default;
|
||||
|
||||
template <ck::index_t SpatialDims, typename InDataType, typename WeiDataType, typename OutDataType>
|
||||
template <ck::index_t SpatialDims,
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType>
|
||||
using DeviceConvNDFwdInstance = ck::tensor_operation::device::
|
||||
DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K<
|
||||
// clang-format off
|
||||
InDataType, //
|
||||
WeiDataType, //
|
||||
OutDataType, //
|
||||
InDataType, //
|
||||
AccDataType, // Accumulator data type.
|
||||
InElementOp, // Input Elementwise Operation
|
||||
WeiElementOp, // Weights Elementwise Operation
|
||||
OutElementOp, // Output Elementwise Operation
|
||||
ConvFwdDefault, // ConvForwardSpecialization
|
||||
SpatialDims, // SptialDims
|
||||
64, // BlockSize
|
||||
16, // MPerBlock
|
||||
16, // NPerBlock
|
||||
256, // BlockSize
|
||||
128, // MPerBlock
|
||||
256, // NPerBlock
|
||||
4, // K0PerBlock
|
||||
1, // K1
|
||||
16, // MPerXDL
|
||||
16, // NPerXDL
|
||||
1, // MXdlPerWave
|
||||
1, // NXdlPerWave
|
||||
S<1, 16, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
8, // K1
|
||||
32, // MPerXdl
|
||||
32, // NPerXdl
|
||||
2, // MXdlPerWave
|
||||
4, // NXdlPerWave
|
||||
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_K0_M_K1
|
||||
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
|
||||
2, // ABlockTransferSrcVectorDim
|
||||
1, // ABlockTransferSrcScalarPerVector
|
||||
1, // ABlockTransferDstScalarPerVector_K1
|
||||
8, // ABlockTransferSrcScalarPerVector
|
||||
8, // ABlockTransferDstScalarPerVector_K1
|
||||
true, // ABlockLdsAddExtraM
|
||||
S<1, 16, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_K0_N_K1
|
||||
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
|
||||
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
|
||||
2, // BBlockTransferSrcVectorDim
|
||||
1, // BBlockTransferSrcScalarPerVector
|
||||
1, // BBlockTransferDstScalarPerVector_K1
|
||||
true, // BBlockTransferAddExtraN
|
||||
8, // BBlockTransferSrcScalarPerVector
|
||||
8, // BBlockTransferDstScalarPerVector_K1
|
||||
true, // BBlockLdsAddExtraN
|
||||
7, // CThreadTransferSrcDstVectorDim
|
||||
1>; // CThreadTransferDstScalarPerVector
|
||||
1>; // CThreadTransferDstScalarPerVector
|
||||
// clang-format on
|
||||
|
||||
template <ck::index_t NDim,
|
||||
typename InDataType = float,
|
||||
typename WeiDataType = float,
|
||||
typename OutDataType = float>
|
||||
typename InDataType,
|
||||
typename WeiDataType,
|
||||
typename OutDataType,
|
||||
typename AccDataType>
|
||||
void get_test_convolution_fwd_instance(std::vector<DeviceConvFwdNoOpPtr>& instances)
|
||||
{
|
||||
using ConvInstanceT = DeviceConvNDFwdInstance<NDim, InDataType, WeiDataType, OutDataType>;
|
||||
using ConvInstanceT =
|
||||
DeviceConvNDFwdInstance<NDim, InDataType, WeiDataType, OutDataType, AccDataType>;
|
||||
instances.emplace_back(std::make_unique<ConvInstanceT>());
|
||||
}
|
||||
|
||||
// TODO (aosewski)
|
||||
// Temporary solution to get all DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
|
||||
// instances. When switched over to DeviceConvNDFwdXdl for 2D remove ConvolutionNDFwdInstances
|
||||
// structures.
|
||||
template <typename InDataType, typename WeiDataType, typename OutDataType>
|
||||
struct ConvolutionNDFwdInstances;
|
||||
|
||||
template <>
|
||||
struct ConvolutionNDFwdInstances<float, float, float>
|
||||
{
|
||||
static std::vector<DeviceConvFwdNoOpPtr> Get(std::size_t num_dim_spatial)
|
||||
{
|
||||
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
|
||||
if(num_dim_spatial == 2)
|
||||
{
|
||||
ck::tensor_operation::device::device_conv2d_fwd_instance::
|
||||
add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f32_instances(conv_ptrs);
|
||||
}
|
||||
return conv_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvolutionNDFwdInstances<ck::half_t, ck::half_t, ck::half_t>
|
||||
{
|
||||
static std::vector<DeviceConvFwdNoOpPtr> Get(std::size_t num_dim_spatial)
|
||||
{
|
||||
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
|
||||
if(num_dim_spatial == 2)
|
||||
{
|
||||
ck::tensor_operation::device::device_conv2d_fwd_instance::
|
||||
add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_f16_instances(conv_ptrs);
|
||||
}
|
||||
return conv_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvolutionNDFwdInstances<ck::bhalf_t, ck::bhalf_t, ck::bhalf_t>
|
||||
{
|
||||
static std::vector<DeviceConvFwdNoOpPtr> Get(std::size_t num_dim_spatial)
|
||||
{
|
||||
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
|
||||
if(num_dim_spatial == 2)
|
||||
{
|
||||
ck::tensor_operation::device::device_conv2d_fwd_instance::
|
||||
add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_bf16_instances(conv_ptrs);
|
||||
}
|
||||
return conv_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct ConvolutionNDFwdInstances<int8_t, int8_t, int8_t>
|
||||
{
|
||||
static std::vector<DeviceConvFwdNoOpPtr> Get(std::size_t num_dim_spatial)
|
||||
{
|
||||
std::vector<DeviceConvFwdNoOpPtr> conv_ptrs;
|
||||
if(num_dim_spatial == 2)
|
||||
{
|
||||
ck::tensor_operation::device::device_conv2d_fwd_instance::
|
||||
add_device_convnd_2d_fwd_xdl_nhwc_kyxc_nhwk_int8_instances(conv_ptrs);
|
||||
}
|
||||
return conv_ptrs;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace conv
|
||||
} // namespace test
|
||||
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user