[CK_BUILDER] Replace reference conv with old ck implementation (#3604)

* ck-builder: remove SPATIAL_DIM parameter from ConvTensorLayouts

This information is already in the SIGNATURE, so its pointless to pass it
separately. This streamlines the interface of those functions a bit. Also
touches up the style of those files in general.

* ck-builder: implement reference conv using old ck

The old ck implementation is more featureful and better tested.

* ck-builder: replace test_reference_execution reference with old ck

This strips out the ck-tile gpu reference implementation completely.

* ck-builder: clean up test_reference_execution

- Remove unneccesary messages
- Replace EXPECT_TRUE(true) with EXPECT_NO_THROW()

[ROCm/composable_kernel commit: 1040d9b1f5]
This commit is contained in:
Robin Voetter
2026-01-21 19:18:47 +01:00
committed by GitHub
parent 5a27de45e5
commit 2b54a86c04
24 changed files with 291 additions and 1067 deletions

View File

@@ -76,7 +76,7 @@ struct Args<SIGNATURE>
using Ops = factory::internal::ConvElementwiseOps<SIGNATURE>;
// TODO: We shouldn't need to call into an internal namespace here.
using Layouts = factory::internal::ConvTensorLayouts<SIGNATURE, SPATIAL_DIM>;
using Layouts = factory::internal::ConvTensorLayouts<SIGNATURE>;
ConvTensorLengths<SPATIAL_DIM> lengths;

View File

@@ -32,27 +32,8 @@ concept RefConvInstance = requires(Conv& conv,
const void* input,
const void* weight,
void* output,
int G,
int N,
int K,
int C,
std::vector<long_index_t> dims) {
{
conv.Run(input,
weight,
output,
G,
N,
K,
C,
dims, // input_spatial
dims, // filter_spatial
dims, // output_spatial
dims, // strides
dims, // dilations
dims // left_pads
)
};
ck::utils::conv::ConvParam param) {
{ conv.Run(input, weight, output, param) };
};
/// @brief `run()` specialization for forward convolution and the reference
@@ -84,16 +65,6 @@ std::tuple<bool, float> run(RefConvInstance<SIGNATURE> auto& conv,
// Just throw for now, but regard these as TODO items that should be resolved
// eventually.
// Right pads are not supported right now for some reason.
for(auto right_pad : param.input_right_pads_)
{
if(right_pad != 0)
{
std::cout << "TODO: Support right pad in reference conv" << std::endl;
return std::make_tuple(false, 0.0f);
}
}
if(!args.make_input_descriptor().is_packed())
{
std::cout << "TODO: Support non-packed input tensor in reference conv" << std::endl;
@@ -110,19 +81,7 @@ std::tuple<bool, float> run(RefConvInstance<SIGNATURE> auto& conv,
return std::make_tuple(false, 0.0f);
}
conv.Run(inputs.input,
inputs.weight,
outputs.output,
param.G_,
param.N_,
param.K_,
param.C_,
param.input_spatial_lengths_,
param.filter_spatial_lengths_,
param.output_spatial_lengths_,
param.conv_filter_strides_,
param.conv_filter_dilations_,
param.input_left_pads_);
conv.Run(inputs.input, inputs.weight, outputs.output, param);
return std::make_tuple(true, 0.0f);
}