Change grouped conv fwd example to run group merging instance.

This commit is contained in:
Ville Pietilä
2026-01-22 04:59:53 -05:00
parent a82db41c88
commit 655d133f58
2 changed files with 25 additions and 15 deletions

View File

@@ -69,7 +69,7 @@ struct CommonLayoutSettingSelector<1> final
template <>
struct CommonLayoutSettingSelector<2> final
: CommonLayoutSetting<ctl::G_NHW_C, ctl::G_K_YX_C, ctl::G_NHW_K>
: CommonLayoutSetting<ctl::NHWGC, ctl::GKYXC, ctl::NHWGK>
{
};
@@ -143,6 +143,12 @@ inline bool parse_cmd_args(int argc,
const ck::index_t num_dim_spatial = std::stoi(argv[4]);
conv_param = ck::utils::conv::parse_conv_param(
num_dim_spatial, threshold_to_catch_partial_args + 1, argv);
std::cout << "parsed conv_param: " << std::endl;
std::cout << conv_param.num_dim_spatial_ << std::endl;
std::cout << conv_param.G_ << std::endl;
std::cout << conv_param.N_ << std::endl;
std::cout << conv_param.K_ << std::endl;
std::cout << conv_param.C_ << std::endl;
}
else
{
@@ -183,7 +189,7 @@ inline HostTensorDescriptor make_input_descriptor(const ck::utils::conv::ConvPar
conv_param.input_spatial_lengths_[1] * conv_param.G_ * conv_param.C_, // hi
conv_param.G_ * conv_param.C_ // wi
},
ck::tensor_layout::convolution::GNCHW{});
ck::tensor_layout::convolution::NHWGC{});
case 3:
return HostTensorDescriptor(
@@ -239,7 +245,7 @@ inline HostTensorDescriptor make_weight_descriptor(const ck::utils::conv::ConvPa
conv_param.filter_spatial_lengths_[1] * conv_param.C_, // y
conv_param.C_ // x
},
ck::tensor_layout::convolution::GKCYX{});
ck::tensor_layout::convolution::GKYXC{});
case 3:
return HostTensorDescriptor(
{conv_param.G_,
@@ -345,7 +351,7 @@ inline HostTensorDescriptor make_output_descriptor(const ck::utils::conv::ConvPa
conv_param.output_spatial_lengths_[1] * conv_param.G_ * conv_param.K_, // ho
conv_param.G_ * conv_param.K_ // wo
},
ck::tensor_layout::convolution::GNKHW{});
ck::tensor_layout::convolution::NHWGK{});
case 3:
return HostTensorDescriptor(

View File

@@ -22,33 +22,37 @@ using DeviceConvFwdInstance =
GemmSpec, // GemmSpecialization
1, //
256, // BlockSize
128, // MPerBlock
256, // NPerBlock
16, // KPerBlock
4, // AK1
4, // BK1
64, // MPerBlock
64, // NPerBlock
32, // KPerBlock
8, // AK1
8, // BK1
16, // MPerXdl
16, // NPerXdl
4, // MXdlPerWave
8, // NXdlPerWave
2, // MXdlPerWave
2, // NXdlPerWave
S<4, 64, 1>, // ABlockTransferThreadClusterLengths_AK0_M_AK1
S<1, 0, 2>, // ABlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // ABlockTransferSrcAccessOrder
2, // ABlockTransferSrcVectorDim
4, // ABlockTransferSrcScalarPerVector
4, // ABlockTransferDstScalarPerVector_AK1
8, // ABlockTransferDstScalarPerVector_AK1
1, // ABlockLdsExtraM
S<4, 64, 1>, // BBlockTransferThreadClusterLengths_BK0_N_BK1
S<1, 0, 2>, // BBlockTransferThreadClusterArrangeOrder
S<1, 0, 2>, // BBlockTransferSrcAccessOrder
2, // BBlockTransferSrcVectorDim
4, // BBlockTransferSrcScalarPerVector
4, // BBlockTransferDstScalarPerVector_BK1
8, // BBlockTransferDstScalarPerVector_BK1
1, // BBlockLdsExtraN
1,
1,
S<1, 32, 1, 8>,
4>;
S<1, 32, 1, 4>,
4, // Vector load/store size for output tensor
InKernelDataType,
WeiKernelDataType,
ck::LoopScheduler::Default,
2>; // Number of merged groups
template <ck::index_t NDimSpatial>
using HostConvFwdInstance = ck::tensor_operation::host::ReferenceConvFwd<NDimSpatial,