mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
Add wei_strides to grouped conv3d wei to keep consistency (#817)
* Add wei_strides to grouped conv3d wei to keep consistency
* Fix strides in client examples
* Unify backward weight api with forward
* Fix for example
* Fixes for examples

---------

Co-authored-by: zjing14 <zhangjing14@gmail.com>
This commit is contained in:
@@ -32,63 +32,49 @@ struct SimpleDeviceMem
|
||||
};
|
||||
|
||||
template <ck::index_t NumDimSpatial>
|
||||
std::size_t GetFlops(ck::index_t G,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths)
|
||||
std::size_t GetFlops(const std::array<ck::index_t, NumDimSpatial>& output_lengths,
|
||||
const std::array<ck::index_t, NumDimSpatial>& filter_lengths)
|
||||
{
|
||||
constexpr ck::index_t spatial_offset = 3;
|
||||
const auto C = filter_lengths[2];
|
||||
// 2 * G * N * K * C * <output spatial lengths product> * <filter spatial lengths product>
|
||||
return static_cast<std::size_t>(2) * G * N * K * C *
|
||||
std::accumulate(std::begin(output_spatial_lengths),
|
||||
std::end(output_spatial_lengths),
|
||||
return static_cast<std::size_t>(2) * C *
|
||||
std::accumulate(std::begin(output_lengths),
|
||||
std::end(output_lengths),
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<>()) *
|
||||
std::accumulate(std::begin(filter_spatial_lengths),
|
||||
std::end(filter_spatial_lengths),
|
||||
std::accumulate(std::begin(filter_lengths) + spatial_offset,
|
||||
std::end(filter_lengths),
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<>());
|
||||
}
|
||||
|
||||
template <typename InDataType, ck::index_t NumDimSpatial>
|
||||
std::size_t GetInputByte(ck::index_t G,
|
||||
ck::index_t N,
|
||||
ck::index_t C,
|
||||
const std::array<ck::index_t, NumDimSpatial>& input_spatial_lengths)
|
||||
std::size_t GetInputByte(const std::array<ck::index_t, NumDimSpatial>& input_lengths)
|
||||
{
|
||||
// sizeof(InDataType) * (G * N * C * <input spatial lengths product>)
|
||||
return sizeof(InDataType) * (G * N * C *
|
||||
std::accumulate(std::begin(input_spatial_lengths),
|
||||
std::end(input_spatial_lengths),
|
||||
return sizeof(InDataType) * (std::accumulate(std::begin(input_lengths),
|
||||
std::end(input_lengths),
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<>()));
|
||||
}
|
||||
|
||||
template <typename WeiDataType, ck::index_t NumDimSpatial>
|
||||
std::size_t GetWeightByte(ck::index_t G,
|
||||
ck::index_t K,
|
||||
ck::index_t C,
|
||||
const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths)
|
||||
std::size_t GetWeightByte(const std::array<ck::index_t, NumDimSpatial>& filter_lengths)
|
||||
{
|
||||
// sizeof(WeiDataType) * (G * K * C * <filter spatial lengths product>)
|
||||
return sizeof(WeiDataType) * (G * K * C *
|
||||
std::accumulate(std::begin(filter_spatial_lengths),
|
||||
std::end(filter_spatial_lengths),
|
||||
return sizeof(WeiDataType) * (std::accumulate(std::begin(filter_lengths),
|
||||
std::end(filter_lengths),
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<>()));
|
||||
}
|
||||
|
||||
template <typename OutDataType, ck::index_t NumDimSpatial>
|
||||
std::size_t GetOutputByte(ck::index_t G,
|
||||
ck::index_t N,
|
||||
ck::index_t K,
|
||||
const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths)
|
||||
std::size_t GetOutputByte(const std::array<ck::index_t, NumDimSpatial>& output_lengths)
|
||||
{
|
||||
// sizeof(OutDataType) * (G * N * K * <output spatial lengths product>);
|
||||
return sizeof(OutDataType) * (G * N * K *
|
||||
std::accumulate(std::begin(output_spatial_lengths),
|
||||
std::end(output_spatial_lengths),
|
||||
return sizeof(OutDataType) * (std::accumulate(std::begin(output_lengths),
|
||||
std::end(output_lengths),
|
||||
static_cast<std::size_t>(1),
|
||||
std::multiplies<std::size_t>()));
|
||||
}
|
||||
@@ -101,14 +87,11 @@ template <ck::index_t NumDimSpatial,
|
||||
typename WeiLayout,
|
||||
typename OutLayout>
|
||||
bool run_grouped_conv_bwd_weight(
|
||||
const ck::index_t G,
|
||||
const ck::index_t N,
|
||||
const ck::index_t K,
|
||||
const ck::index_t C,
|
||||
const std::array<ck::index_t, NumDimSpatial>& input_spatial_lengths,
|
||||
const std::array<ck::index_t, NumDimSpatial>& filter_spatial_lengths,
|
||||
const std::array<ck::index_t, NumDimSpatial>& output_spatial_lengths,
|
||||
const std::array<ck::index_t, NumDimSpatial + 3>& input_lengths,
|
||||
const std::array<ck::index_t, NumDimSpatial + 3>& input_strides,
|
||||
const std::array<ck::index_t, NumDimSpatial + 3>& filter_lengths,
|
||||
const std::array<ck::index_t, NumDimSpatial + 3>& weights_strides,
|
||||
const std::array<ck::index_t, NumDimSpatial + 3>& output_lengths,
|
||||
const std::array<ck::index_t, NumDimSpatial + 3>& output_strides,
|
||||
const std::array<ck::index_t, NumDimSpatial>& conv_filter_strides,
|
||||
const std::array<ck::index_t, NumDimSpatial>& conv_filter_dilations,
|
||||
@@ -117,9 +100,9 @@ bool run_grouped_conv_bwd_weight(
|
||||
{
|
||||
|
||||
ck::index_t split_k = 2;
|
||||
SimpleDeviceMem in(GetInputByte<InDataType, NumDimSpatial>(G, N, C, input_spatial_lengths));
|
||||
SimpleDeviceMem wei(GetWeightByte<WeiDataType, NumDimSpatial>(G, K, C, filter_spatial_lengths));
|
||||
SimpleDeviceMem out(GetOutputByte<OutDataType, NumDimSpatial>(G, N, K, output_spatial_lengths));
|
||||
SimpleDeviceMem in(GetInputByte<InDataType, NumDimSpatial + 3>(input_lengths));
|
||||
SimpleDeviceMem wei(GetWeightByte<WeiDataType, NumDimSpatial + 3>(filter_lengths));
|
||||
SimpleDeviceMem out(GetOutputByte<OutDataType, NumDimSpatial + 3>(output_lengths));
|
||||
|
||||
using DeviceOp = ck::tensor_operation::device::DeviceGroupedConvBwdWeight<NumDimSpatial,
|
||||
InLayout,
|
||||
@@ -143,6 +126,10 @@ bool run_grouped_conv_bwd_weight(
|
||||
float best_gb_per_sec = 0;
|
||||
float best_tflops = 0;
|
||||
|
||||
std::array<ck::index_t, NumDimSpatial + 3> a_g_n_c_wis_lengths{};
|
||||
std::array<ck::index_t, NumDimSpatial + 3> a_g_n_c_wis_strides{};
|
||||
std::array<ck::index_t, NumDimSpatial + 3> b_g_k_c_xs_lengths{};
|
||||
|
||||
// profile device operation instances
|
||||
std::cout << "Run all instances and do timing" << std::endl;
|
||||
|
||||
@@ -152,14 +139,11 @@ bool run_grouped_conv_bwd_weight(
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
out.GetDeviceBuffer(),
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_lengths,
|
||||
input_strides,
|
||||
filter_lengths,
|
||||
weights_strides,
|
||||
output_lengths,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
@@ -176,12 +160,10 @@ bool run_grouped_conv_bwd_weight(
|
||||
{
|
||||
float avg_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});
|
||||
|
||||
std::size_t flop =
|
||||
GetFlops<NumDimSpatial>(G, N, K, C, output_spatial_lengths, filter_spatial_lengths);
|
||||
std::size_t num_bytes =
|
||||
GetInputByte<InDataType, NumDimSpatial>(G, N, C, input_spatial_lengths) +
|
||||
GetWeightByte<WeiDataType, NumDimSpatial>(G, K, C, filter_spatial_lengths) +
|
||||
GetOutputByte<OutDataType, NumDimSpatial>(G, N, K, output_spatial_lengths);
|
||||
std::size_t flop = GetFlops<NumDimSpatial + 3>(output_lengths, filter_lengths);
|
||||
std::size_t num_bytes = GetInputByte<InDataType, NumDimSpatial + 3>(input_lengths) +
|
||||
GetWeightByte<WeiDataType, NumDimSpatial + 3>(filter_lengths) +
|
||||
GetOutputByte<OutDataType, NumDimSpatial + 3>(output_lengths);
|
||||
|
||||
float tflops = static_cast<float>(flop) / 1.E9 / avg_time;
|
||||
float gb_per_sec = num_bytes / 1.E6 / avg_time;
|
||||
@@ -221,14 +203,11 @@ bool run_grouped_conv_bwd_weight(
|
||||
auto argument_ptr = op_ptr->MakeArgumentPointer(in.GetDeviceBuffer(),
|
||||
wei.GetDeviceBuffer(),
|
||||
out.GetDeviceBuffer(),
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
input_lengths,
|
||||
input_strides,
|
||||
filter_lengths,
|
||||
weights_strides,
|
||||
output_lengths,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
|
||||
@@ -22,11 +22,12 @@ static constexpr ck::index_t C = 192;
|
||||
static constexpr ck::index_t X = 3;
|
||||
static constexpr ck::index_t Wi = 28;
|
||||
static constexpr ck::index_t Wo = 28;
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Wi};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{X};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Wo};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{N * Wi * C, Wi* C, C, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{N * Wo * K, Wo* K, K, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_lengths{G, N, C, Wi};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> filter_lengths{G, K, C, X};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_lengths{G, N, K, Wo};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{N * Wi * C, Wi* C, 1, C};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{K * X * C, X* C, 1, C};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{N * Wo * K, Wo* K, 1, K};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1};
|
||||
@@ -40,14 +41,11 @@ int main()
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
OutLayout>(input_lengths,
|
||||
input_strides,
|
||||
filter_lengths,
|
||||
weights_strides,
|
||||
output_lengths,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
|
||||
@@ -25,13 +25,15 @@ static constexpr ck::index_t Hi = 28;
|
||||
static constexpr ck::index_t Wi = 28;
|
||||
static constexpr ck::index_t Ho = 28;
|
||||
static constexpr ck::index_t Wo = 28;
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Hi, Wi};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Y, X};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Ho, Wo};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_lengths{G, N, C, Hi, Wi};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> filter_lengths{G, K, C, Y, X};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_lengths{G, N, K, Ho, Wo};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
|
||||
N * Hi * Wi * C, Hi* Wi* C, Wi* C, C, 1};
|
||||
N * Hi * Wi * C, Hi* Wi* C, 1, Wi* C, C};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{
|
||||
K * Y * X * C, Y* X* C, 1, X* C, C};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
|
||||
N * Ho * Wo * K, Ho* Wo* K, Wo* K, K, 1};
|
||||
N * Ho * Wo * K, Ho* Wo* K, 1, Wo* K, K};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1};
|
||||
@@ -45,14 +47,11 @@ int main()
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
OutLayout>(input_lengths,
|
||||
input_strides,
|
||||
filter_lengths,
|
||||
weights_strides,
|
||||
output_lengths,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
|
||||
@@ -28,13 +28,15 @@ static constexpr ck::index_t Wi = 3;
|
||||
static constexpr ck::index_t Do = 28;
|
||||
static constexpr ck::index_t Ho = 28;
|
||||
static constexpr ck::index_t Wo = 3;
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Di, Hi, Wi};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Z, Y, X};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Do, Ho, Wo};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_lengths{G, N, C, Di, Hi, Wi};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> filter_lengths{G, K, C, Z, Y, X};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_lengths{G, N, K, Do, Ho, Wo};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
|
||||
N * Di * Hi * Wi * C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1};
|
||||
N * Di * Hi * Wi * C, Di* Hi* Wi* C, 1, Hi* Wi* C, Wi* C, C};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{
|
||||
K * Z * Y * X * C, Z* Y* X* C, 1, Y* X* C, X* C, C};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
|
||||
N * Do * Ho * Wo * K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1};
|
||||
N * Do * Ho * Wo * K, Do* Ho* Wo* K, 1, Ho* Wo* K, Wo* K, K};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1, 1};
|
||||
@@ -48,14 +50,11 @@ int main()
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
input_spatial_lengths,
|
||||
filter_spatial_lengths,
|
||||
output_spatial_lengths,
|
||||
OutLayout>(input_lengths,
|
||||
input_strides,
|
||||
filter_lengths,
|
||||
weights_strides,
|
||||
output_lengths,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
|
||||
@@ -28,13 +28,15 @@ static constexpr ck::index_t Wi = 3;
|
||||
static constexpr ck::index_t Do = 28;
|
||||
static constexpr ck::index_t Ho = 28;
|
||||
static constexpr ck::index_t Wo = 3;
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_spatial_lengths{Di, Hi, Wi};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> filter_spatial_lengths{Z, Y, X};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> output_spatial_lengths{Do, Ho, Wo};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_lengths{G, N, C, Di, Hi, Wi};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> filter_lengths{G, K, C, Z, Y, X};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_lengths{G, N, K, Do, Ho, Wo};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> input_strides{
|
||||
N * Di * Hi * Wi * C, Di* Hi* Wi* C, Hi* Wi* C, Wi* C, C, 1};
|
||||
N * Di * Hi * Wi * C, Di* Hi* Wi* C, 1, Hi* Wi* C, Wi* C, C};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> weights_strides{
|
||||
K * Z * Y * X * C, Z* Y* X* C, 1, Y* X* C, X* C, C};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial + 3> output_strides{
|
||||
N * Do * Ho * Wo * K, Do* Ho* Wo* K, Ho* Wo* K, Wo* K, K, 1};
|
||||
N * Do * Ho * Wo * K, Do* Ho* Wo* K, 1, Ho* Wo* K, Wo* K, K};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_strides{1, 1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> conv_filter_dilations{1, 1, 1};
|
||||
static constexpr std::array<ck::index_t, NumDimSpatial> input_left_pads{1, 1, 1};
|
||||
@@ -48,20 +50,16 @@ int main()
|
||||
OutDataType,
|
||||
InLayout,
|
||||
WeiLayout,
|
||||
OutLayout>(
|
||||
G,
|
||||
N,
|
||||
K,
|
||||
C,
|
||||
{Di, Hi, Wi},
|
||||
{Z, Y, X},
|
||||
{Do, Ho, Wo},
|
||||
{N * Di * Hi * Wi * C, Di * Hi * Wi * C, Hi * Wi * C, Wi * C, C, 1},
|
||||
{N * Do * Ho * Wo * K, Do * Ho * Wo * K, Ho * Wo * K, Wo * K, K, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1},
|
||||
{1, 1, 1})
|
||||
OutLayout>(input_lengths,
|
||||
input_strides,
|
||||
filter_lengths,
|
||||
weights_strides,
|
||||
output_lengths,
|
||||
output_strides,
|
||||
conv_filter_strides,
|
||||
conv_filter_dilations,
|
||||
input_left_pads,
|
||||
input_right_pads)
|
||||
? EXIT_SUCCESS
|
||||
: EXIT_FAILURE;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user