mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 02:02:46 +00:00
* add DeviceGemmXdl
* update script
* fix naming issue
* fix comment
* output HostTensorDescriptor
* rename
* padded GEMM for fwd v4r4r4 nhwc
* refactor
* refactor
* refactor
* adding ckProfiler
* adding ckProfiler
* refactor
* fix tuning parameter bug
* add more gemm instances
* add more fp16 GEMM instances
* fix profiler driver
* fix bug in tuning parameter
* add fp32 gemm instances
* small fix
* refactor
* rename
* refactor gemm profiler; adding DeviceConv and conv profiler
* refactor
* fix
* add conv profiler
* refactor
* adding more GEMM and Conv instances
* Create README.md
Add build instructions for ckProfiler
* Create README.md
Add Readme for gemm_xdl example
* Update README.md
Remove build instructions from the topmost folder
* Update README.md
* clean up
[ROCm/composable_kernel commit: e823d518cb]
51 lines
1.9 KiB
C++
51 lines
1.9 KiB
C++
#pragma once
|
|
#include "host_tensor.hpp"
|
|
#include "conv_common.hpp"
|
|
|
|
template <typename TIn,
|
|
typename TWei,
|
|
typename TOut,
|
|
typename ConvStrides,
|
|
typename ConvDilations,
|
|
typename InLeftPads,
|
|
typename InRightPads>
|
|
void host_conv_nchw_kcyx_nkhw(const Tensor<TIn>& in,
|
|
const Tensor<TWei>& wei,
|
|
Tensor<TOut>& out,
|
|
const ConvStrides& conv_strides,
|
|
const ConvDilations& conv_dilations,
|
|
const InLeftPads& in_left_pads,
|
|
const InRightPads&)
|
|
{
|
|
constexpr auto I0 = ck::Number<0>{};
|
|
constexpr auto I1 = ck::Number<1>{};
|
|
|
|
auto f_nchw = [&](auto n, auto k, auto ho, auto wo) {
|
|
double v = 0;
|
|
for(int c = 0; c < wei.mDesc.GetLengths()[1]; ++c)
|
|
{
|
|
for(int y = 0; y < wei.mDesc.GetLengths()[2]; ++y)
|
|
{
|
|
int hi = ho * conv_strides[I0] + y * conv_dilations[I0] - in_left_pads[I0];
|
|
for(int x = 0; x < wei.mDesc.GetLengths()[3]; ++x)
|
|
{
|
|
int wi = wo * conv_strides[I1] + x * conv_dilations[I1] - in_left_pads[I1];
|
|
if(hi >= 0 && hi < in.mDesc.GetLengths()[2] && wi >= 0 &&
|
|
wi < in.mDesc.GetLengths()[3])
|
|
{
|
|
v += static_cast<const double>(in(n, c, hi, wi)) *
|
|
static_cast<const double>(wei(k, c, y, x));
|
|
}
|
|
}
|
|
}
|
|
}
|
|
out(n, k, ho, wo) = v;
|
|
};
|
|
|
|
make_ParallelTensorFunctor(f_nchw,
|
|
out.mDesc.GetLengths()[0],
|
|
out.mDesc.GetLengths()[1],
|
|
out.mDesc.GetLengths()[2],
|
|
out.mDesc.GetLengths()[3])(std::thread::hardware_concurrency());
|
|
}
|