mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
added strides and dilations suppport to implicit gemm v4
This commit is contained in:
@@ -103,10 +103,18 @@ auto make_TensorDescriptor(TConstTensorDesc)
|
||||
return TensorDescriptor(lengths, strides);
|
||||
}
|
||||
|
||||
template <class TIn, class TWei, class TOut, class LowerPads, class UpperPads>
|
||||
template <class TIn,
|
||||
class TWei,
|
||||
class TOut,
|
||||
class ConvStrides,
|
||||
class ConvDilations,
|
||||
class LowerPads,
|
||||
class UpperPads>
|
||||
void host_direct_convolution(const Tensor<TIn>& in_nchw,
|
||||
const Tensor<TWei>& wei_kcyx,
|
||||
Tensor<TOut>& out_nkhw,
|
||||
ConvStrides,
|
||||
ConvDilations,
|
||||
LowerPads,
|
||||
UpperPads)
|
||||
{
|
||||
@@ -122,10 +130,10 @@ void host_direct_convolution(const Tensor<TIn>& in_nchw,
|
||||
{
|
||||
for(int y = 0; y < wei_kcyx.mDesc.GetLengths()[2]; ++y)
|
||||
{
|
||||
int hi = ho + y - h_pad_low;
|
||||
int hi = ho * ConvStrides{}[0] + y * ConvDilations{}[0] - h_pad_low;
|
||||
for(int x = 0; x < wei_kcyx.mDesc.GetLengths()[3]; ++x)
|
||||
{
|
||||
int wi = wo + x - w_pad_low;
|
||||
int wi = wo * ConvStrides{}[1] + x * ConvDilations{}[1] - w_pad_low;
|
||||
if(hi >= 0 && hi < in_nchw.mDesc.GetLengths()[2] && wi >= 0 &&
|
||||
wi < in_nchw.mDesc.GetLengths()[3])
|
||||
{
|
||||
@@ -419,9 +427,9 @@ int main(int argc, char* argv[])
|
||||
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 1
|
||||
#elif 0
|
||||
// 3x3, 34x34
|
||||
constexpr index_t N = 64;
|
||||
constexpr index_t N = 128;
|
||||
constexpr index_t C = 256;
|
||||
constexpr index_t HI = 34;
|
||||
constexpr index_t WI = 34;
|
||||
@@ -429,6 +437,9 @@ int main(int argc, char* argv[])
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 0
|
||||
@@ -453,6 +464,9 @@ int main(int argc, char* argv[])
|
||||
constexpr index_t Y = 3;
|
||||
constexpr index_t X = 3;
|
||||
|
||||
using ConvStrides = Sequence<2, 2>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
constexpr index_t HPad = 0;
|
||||
constexpr index_t WPad = 0;
|
||||
#elif 0
|
||||
@@ -583,7 +597,7 @@ int main(int argc, char* argv[])
|
||||
auto in_nchw_desc = make_ConstantTensorDescriptor_packed(Sequence<N, C, HI, WI>{});
|
||||
auto wei_kcyx_desc = make_ConstantTensorDescriptor_packed(Sequence<K, C, Y, X>{});
|
||||
auto out_nkhw_desc = get_convolution_with_padding_output_default_4d_tensor_descriptor(
|
||||
in_nchw_desc, wei_kcyx_desc, lower_pads, upper_pads);
|
||||
in_nchw_desc, wei_kcyx_desc, ConvStrides{}, ConvDilations{}, lower_pads, upper_pads);
|
||||
|
||||
ostream_ConstantTensorDescriptor(in_nchw_desc, std::cout << "in_nchw_desc: ");
|
||||
ostream_ConstantTensorDescriptor(wei_kcyx_desc, std::cout << "wei_kcyx_desc: ");
|
||||
@@ -645,9 +659,17 @@ int main(int argc, char* argv[])
|
||||
#elif 1
|
||||
device_convolution_implicit_gemm_v4_nchw_kcyx_nkhw
|
||||
#endif
|
||||
(in_nchw_desc, in_nchw, wei_kcyx_desc, wei_kcyx, out_nkhw_desc, out_nkhw_device, nrepeat);
|
||||
(in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
wei_kcyx,
|
||||
out_nkhw_desc,
|
||||
out_nkhw_device,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
nrepeat);
|
||||
|
||||
#elif 1
|
||||
#elif 0
|
||||
device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded(in_nchw_desc,
|
||||
in_nchw,
|
||||
wei_kcyx_desc,
|
||||
@@ -662,14 +684,21 @@ int main(int argc, char* argv[])
|
||||
if(do_verification)
|
||||
{
|
||||
#if 1
|
||||
if(Y == 3 && X == 3)
|
||||
if(Y == 3 && X == 3 && ConvStrides{}[0] == 1 && ConvStrides{}[1] == 1 &&
|
||||
ConvDilations{}[0] == 1 && ConvDilations{}[1] == 1)
|
||||
{
|
||||
host_winograd_3x3_convolution(in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads);
|
||||
}
|
||||
else
|
||||
#endif
|
||||
{
|
||||
host_direct_convolution(in_nchw, wei_kcyx, out_nkhw_host, lower_pads, upper_pads);
|
||||
host_direct_convolution(in_nchw,
|
||||
wei_kcyx,
|
||||
out_nkhw_host,
|
||||
ConvStrides{},
|
||||
ConvDilations{},
|
||||
lower_pads,
|
||||
upper_pads);
|
||||
}
|
||||
check_error(out_nkhw_host, out_nkhw_device);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user