mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-20 14:59:17 +00:00
Refactor transform conv to gemm fwd (#1391)
* Refactor transform conv to gemm fwd
* fixes codegen
* wmma fixes
* fix wmma
* Fix copyright
This commit is contained in:
@@ -1,3 +1,6 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck/utility/common_header.hpp"
|
||||
@@ -95,16 +98,27 @@ auto transform_conv(ck::index_t num_dim,
|
||||
ck::Array<ck::index_t, 5> out_lengths,
|
||||
ck::Array<ck::index_t, 5> out_strides)
|
||||
{
|
||||
ck::Array<ck::index_t, 5> dummy_dims;
|
||||
ck::Array<ck::index_t, 2> dummy_spatial_dims;
|
||||
if(num_dim == 2 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
2,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>
|
||||
conv_fwd;
|
||||
conv_fwd{dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims};
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
return res.transform_func(conv_fwd);
|
||||
}
|
||||
if(num_dim == 2 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0)
|
||||
@@ -112,10 +126,19 @@ auto transform_conv(ck::index_t num_dim,
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
2,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>
|
||||
conv_fwd;
|
||||
conv_fwd{dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims};
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
return res.transform_func(conv_fwd);
|
||||
}
|
||||
if(num_dim == 2 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
|
||||
@@ -123,20 +146,38 @@ auto transform_conv(ck::index_t num_dim,
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
2,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>
|
||||
conv_fwd;
|
||||
conv_fwd{dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims};
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
return res.transform_func(conv_fwd);
|
||||
}
|
||||
if(num_dim == 2 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
2,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>
|
||||
conv_fwd;
|
||||
conv_fwd{dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims};
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
return res.transform_func(conv_fwd);
|
||||
}
|
||||
throw std::runtime_error("Incorrect conv spec");
|
||||
}
|
||||
@@ -146,16 +187,28 @@ auto transform_conv_3d(ck::index_t num_dim,
|
||||
ck::Array<ck::index_t, 6> out_lengths,
|
||||
ck::Array<ck::index_t, 6> out_strides)
|
||||
{
|
||||
ck::Array<ck::index_t, 6> dummy_dims;
|
||||
ck::Array<ck::index_t, 3> dummy_spatial_dims;
|
||||
|
||||
if(num_dim == 3 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
3,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>
|
||||
conv_fwd;
|
||||
conv_fwd{dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims};
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
return res.transform_func(conv_fwd);
|
||||
}
|
||||
if(num_dim == 3 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0)
|
||||
@@ -163,10 +216,19 @@ auto transform_conv_3d(ck::index_t num_dim,
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
3,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>
|
||||
conv_fwd;
|
||||
conv_fwd{dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims};
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
return res.transform_func(conv_fwd);
|
||||
}
|
||||
if(num_dim == 3 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
|
||||
@@ -174,20 +236,38 @@ auto transform_conv_3d(ck::index_t num_dim,
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
3,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>
|
||||
conv_fwd;
|
||||
conv_fwd{dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims};
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
return res.transform_func(conv_fwd);
|
||||
}
|
||||
if(num_dim == 3 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
3,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>
|
||||
conv_fwd;
|
||||
conv_fwd{dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims};
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
return res.transform_func(conv_fwd);
|
||||
}
|
||||
throw std::runtime_error("Incorrect conv spec");
|
||||
}
|
||||
@@ -197,16 +277,28 @@ auto transform_conv_1d(ck::index_t num_dim,
|
||||
ck::Array<ck::index_t, 4> out_lengths,
|
||||
ck::Array<ck::index_t, 4> out_strides)
|
||||
{
|
||||
ck::Array<ck::index_t, 4> dummy_dims;
|
||||
ck::Array<ck::index_t, 1> dummy_spatial_dims;
|
||||
|
||||
if(num_dim == 1 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Default)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
1,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Default>
|
||||
conv_fwd;
|
||||
conv_fwd{dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims};
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
return res.transform_func(conv_fwd);
|
||||
}
|
||||
if(num_dim == 1 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0)
|
||||
@@ -214,10 +306,19 @@ auto transform_conv_1d(ck::index_t num_dim,
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
1,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0>
|
||||
conv_fwd;
|
||||
conv_fwd{dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims};
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
return res.transform_func(conv_fwd);
|
||||
}
|
||||
if(num_dim == 1 &&
|
||||
spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0)
|
||||
@@ -225,20 +326,38 @@ auto transform_conv_1d(ck::index_t num_dim,
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
1,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0>
|
||||
conv_fwd;
|
||||
conv_fwd{dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims};
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
return res.transform_func(conv_fwd);
|
||||
}
|
||||
if(num_dim == 1 && spec == ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC)
|
||||
{
|
||||
ck::tensor_operation::TransformConvFwdToGemm<
|
||||
1,
|
||||
ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC>
|
||||
conv_fwd;
|
||||
conv_fwd{dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
dummy_dims,
|
||||
out_lengths,
|
||||
out_strides,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims,
|
||||
dummy_spatial_dims};
|
||||
|
||||
auto res = ck::tensor_operation::TransformConv();
|
||||
return res.transform_func(out_lengths, out_strides, conv_fwd);
|
||||
return res.transform_func(conv_fwd);
|
||||
}
|
||||
throw std::runtime_error("Incorrect dims or conv spec");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user