Add support for GKCYX grouped conv fwd (#2015)

* Add support for GKCYX grouped conv fwd

* fixes

* fix

* changelog

* Fixes

[ROCm/composable_kernel commit: 54c81a1fcf]
This commit is contained in:
Bartłomiej Kocot
2025-03-26 21:13:38 +01:00
committed by GitHub
parent ba16351a03
commit f967fd7296
39 changed files with 1005 additions and 570 deletions

View File

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <numeric>
@@ -16,6 +16,7 @@ enum struct ConvLayout
GNHWC_GKYXC_GNHWK, // 0
NHWGC_GKYXC_NHWGK, // 1
NGCHW_GKYXC_NGKHW, // 2
NGCHW_GKCYX_NGKHW, // 3
};
enum struct ConvDataType
@@ -52,11 +53,13 @@ static void print_helper_msg()
<< " 5: Input bf8, Weight bf8, Output fp8\n"
<< " 6: Input fp8, Weight bf8, Output fp8\n"
<< " 7: Input bf8, Weight fp8, Output fp8)\n"
<< "arg3: indexing data type (0: 32-bit, 1: 64-bit)\n"
<< "arg4: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n"
<< " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K])\n"
<< "arg3: tensor layout (0: Input[G, N, Hi, Wi, C], Weight[G, K, Y, X, C], Output[G, N, Ho, Wo, K]\n"
<< " 1: Input[N, Hi, Wi, G, C], Weight[G, K, Y, X, C], Output[N, Ho, Wo, G, K]\n"
<< " 2: Input[N, G, C, Hi, Wi], Weight[G, K, Y, X, C], Output[N, "
"G, K, Ho, Wo]\n"
"G, K, Ho, Wo]\n"
<< " 3: Input[N, G, C, Hi, Wi], Weight[G, K, C, Y, X], Output[N, "
"G, K, Ho, Wo])\n"
<< "arg4: indexing data type (0: 32-bit, 1: 64-bit)\n"
<< "arg5: verification (0: no, 1: yes)\n"
<< "arg6: initialization (0: no init, 1: integer value, 2: decimal value)\n"
<< "arg7: print tensor value (0: no; 1: yes)\n"
@@ -110,6 +113,10 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
using GKYXC = ck::tensor_layout::convolution::GKYXC;
using GKZYXC = ck::tensor_layout::convolution::GKZYXC;
// using GKCX = ck::tensor_layout::convolution::GKXC;
using GKCYX = ck::tensor_layout::convolution::GKCYX;
// using GKCZYX = ck::tensor_layout::convolution::GKZYXC;
using GNWK = ck::tensor_layout::convolution::GNWK;
using GNHWK = ck::tensor_layout::convolution::GNHWK;
using GNDHWK = ck::tensor_layout::convolution::GNDHWK;
@@ -302,6 +309,25 @@ int profile_grouped_conv_fwd(int argc, char* argv[])
{
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(I2, NGCHW{}, GKYXC{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
}
else if(num_dim_spatial == 2 && layout == ConvLayout::NGCHW_GKCYX_NGKHW)
{
if(data_type == ConvDataType::F32_F32_F32)
{
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F32{}, F32{}, F32{}, F32{}, F32{});
}
else if(data_type == ConvDataType::F16_F16_F16)
{
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, F16{}, F16{}, F16{}, F16{}, F16{});
}
else if(data_type == ConvDataType::BF16_BF16_BF16)
{
return profile(I2, NGCHW{}, GKCYX{}, NGKHW{}, BF16{}, BF16{}, BF16{}, BF16{}, BF16{});
}
}
else if(num_dim_spatial == 3 && layout == ConvLayout::NHWGC_GKYXC_NHWGK)
{