mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK Tile] Fix example for conv fwd + bias + clamp (#3235)
* Fix clamp not being applied correctly * Apply group offsets to D tensors --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
@@ -184,11 +184,6 @@ int run_grouped_conv_fwd_bias_clamp_example_with_layouts(
|
||||
|
||||
if(arg_parser.get_int("v") == 1)
|
||||
{
|
||||
// FIXME: Address this issue
|
||||
if(arg_parser.get_int("g") > 1 && init_method == 0)
|
||||
std::cerr << "Adding different bias to different groups yield incorrect results"
|
||||
<< std::endl;
|
||||
|
||||
ck_tile::HostTensor<OutDataType> output_host_ref(out_g_n_k_wos_desc);
|
||||
output_host_ref.SetZero();
|
||||
|
||||
@@ -250,29 +245,27 @@ template <typename Invoker,
|
||||
int run_grouped_conv_fwd_bias_clamp_example_prec_type(
|
||||
std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
|
||||
{
|
||||
// using NWGC = ck_tile::tensor_layout::convolution::NWGC;
|
||||
using NWGC = ck_tile::tensor_layout::convolution::NWGC;
|
||||
using NHWGC = ck_tile::tensor_layout::convolution::NHWGC;
|
||||
using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC;
|
||||
|
||||
// using GKXC = ck_tile::tensor_layout::convolution::GKXC;
|
||||
using GKXC = ck_tile::tensor_layout::convolution::GKXC;
|
||||
using GKYXC = ck_tile::tensor_layout::convolution::GKYXC;
|
||||
using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC;
|
||||
|
||||
// using NWGK = ck_tile::tensor_layout::convolution::NWGK;
|
||||
using NWGK = ck_tile::tensor_layout::convolution::NWGK;
|
||||
using NHWGK = ck_tile::tensor_layout::convolution::NHWGK;
|
||||
using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK;
|
||||
|
||||
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
|
||||
{
|
||||
// FIXME: Fix crash in 1D convolution whem using Ds tensor.
|
||||
throw std::runtime_error("1D Convolution does not support bias.");
|
||||
// return run_grouped_conv_fwd_bias_clamp_example_with_layouts<ck_tile::number<1>{},
|
||||
// GemmWarpConfig,
|
||||
// Invoker,
|
||||
// InPrecType,
|
||||
// WeiPrecType,
|
||||
// OutPrecType>(
|
||||
// argc, argv, NWGC{}, GKXC{}, NWGK{});
|
||||
return run_grouped_conv_fwd_bias_clamp_example_with_layouts<ck_tile::number<1>{},
|
||||
GemmWarpConfig,
|
||||
Invoker,
|
||||
InPrecType,
|
||||
WeiPrecType,
|
||||
OutPrecType>(
|
||||
argc, argv, NWGC{}, GKXC{}, NWGK{});
|
||||
}
|
||||
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user