[CK Tile] Fix example for conv fwd + bias + clamp (#3235)

* Fix clamp not being applied correctly

* Apply group offsets to D tensors

---------

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
Johannes Graner
2025-11-24 07:36:26 +01:00
committed by GitHub
parent f6c999bddb
commit 096f0a3b23
2 changed files with 36 additions and 27 deletions

View File

@@ -184,11 +184,6 @@ int run_grouped_conv_fwd_bias_clamp_example_with_layouts(
if(arg_parser.get_int("v") == 1)
{
// FIXME: Address this issue
if(arg_parser.get_int("g") > 1 && init_method == 0)
std::cerr << "Adding different bias to different groups yield incorrect results"
<< std::endl;
ck_tile::HostTensor<OutDataType> output_host_ref(out_g_n_k_wos_desc);
output_host_ref.SetZero();
@@ -250,29 +245,27 @@ template <typename Invoker,
int run_grouped_conv_fwd_bias_clamp_example_prec_type(
std::string in_layout, std::string wei_layout, std::string out_layout, int argc, char* argv[])
{
// using NWGC = ck_tile::tensor_layout::convolution::NWGC;
using NWGC = ck_tile::tensor_layout::convolution::NWGC;
using NHWGC = ck_tile::tensor_layout::convolution::NHWGC;
using NDHWGC = ck_tile::tensor_layout::convolution::NDHWGC;
// using GKXC = ck_tile::tensor_layout::convolution::GKXC;
using GKXC = ck_tile::tensor_layout::convolution::GKXC;
using GKYXC = ck_tile::tensor_layout::convolution::GKYXC;
using GKZYXC = ck_tile::tensor_layout::convolution::GKZYXC;
// using NWGK = ck_tile::tensor_layout::convolution::NWGK;
using NWGK = ck_tile::tensor_layout::convolution::NWGK;
using NHWGK = ck_tile::tensor_layout::convolution::NHWGK;
using NDHWGK = ck_tile::tensor_layout::convolution::NDHWGK;
if(in_layout == "NWGC" && wei_layout == "GKXC" && out_layout == "NWGK")
{
// FIXME: Fix crash in 1D convolution whem using Ds tensor.
throw std::runtime_error("1D Convolution does not support bias.");
// return run_grouped_conv_fwd_bias_clamp_example_with_layouts<ck_tile::number<1>{},
// GemmWarpConfig,
// Invoker,
// InPrecType,
// WeiPrecType,
// OutPrecType>(
// argc, argv, NWGC{}, GKXC{}, NWGK{});
return run_grouped_conv_fwd_bias_clamp_example_with_layouts<ck_tile::number<1>{},
GemmWarpConfig,
Invoker,
InPrecType,
WeiPrecType,
OutPrecType>(
argc, argv, NWGC{}, GKXC{}, NWGK{});
}
else if(in_layout == "NHWGC" && wei_layout == "GKYXC" && out_layout == "NHWGK")
{