[CK_BUILDER] conv bwd weight testing (#3618)

* ck-builder: restructure testing conv

In order to prepare for bwd of conv testing, this commit moves some
files and types around so that we can reuse ckt::Args for both forward
and backwards convolution.

* ck-builder: decouple fwd_ck.hpp and fwd_reference.hpp from fwd.hpp

This will allow us to more easily include fwd.hpp from backwards
definitions, which is required for initializing bwd values.

* ck-builder: fix layout of test_ckb_conv_bwd_weight_xdl_cshuffle_v3

Turns out that the supplied layout isn't actually supported...

* ck-builder: ck and reference conv integration for bwd weight

* ck-builder: ck bwd weight execution test

* ck-builder: ckt::run support for ck-tile bwd weight

* ck-builder: ck tile bwd weight execution test

* ck-builder: extra debug printing in MatchesReference

* ck-builder: make ckt::run return RunResult

This type is more convenient than std::tuple, as it will allow us to
use google test matchers with this in the future.

* ck-builder: RunResult matcher

Using EXPECT_THAT(..., SuccessfulRun()) will generate a check and a nice error
message about how and why running an algorithm failed.

* ck-builder: doc fixes

* ck-builder: add missing headers
This commit is contained in:
Robin Voetter
2026-01-26 23:50:15 +01:00
committed by GitHub
parent 8654c0628f
commit cc75948d1c
27 changed files with 939 additions and 262 deletions

View File

@@ -1,23 +1,30 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#include "ck_tile/builder/testing/conv/bwd_weight.hpp"
#include "ck_tile/builder/testing/conv/bwd_weight_ck.hpp"
#include "ck_tile/builder/testing/conv/reference.hpp"
#include "ck_tile/host/device_prop.hpp"
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
#include "utils/conv_algorithm_type_utils.hpp"
#include "ck_tile/host/device_prop.hpp"
#include "testing_utils.hpp"
namespace ckb = ck_tile::builder;
namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
using enum ck_tile::builder::TensorLayout;
using ck_tile::test::MatchesReference;
using ck_tile::test::SuccessfulRun;
constexpr auto SIGNATURE = ckt::ConvSignature{.spatial_dim = 1,
.direction = ckb::ConvDirection::BACKWARD_WEIGHT,
.data_type = ckb::DataType::BF16,
.accumulation_data_type = ckb::DataType::FP32,
.input = {.config = {.layout = NGCW}},
.input = {.config = {.layout = GNWC}},
.weight = {.config = {.layout = GKXC}},
.output = {.config = {.layout = NGKW}}};
.output = {.config = {.layout = GNWK}}};
constexpr auto ALGORITHM =
cku::ConvAlgorithm_DeviceGroupedConvBwdWeight_Xdl_CShuffle_V3{}
@@ -30,14 +37,58 @@ constexpr auto ALGORITHM =
using Builder = ckb::ConvBuilder<SIGNATURE, ALGORITHM>;
using Instance = Builder::Instance;
using Reference = ckb::ConvBuilder<SIGNATURE, ckt::ConvAlgorithm_Reference{}>::Instance;
TEST(BwdWeight_1DBf16_CShuffle_V3, Create)
{
const auto expected_transfer_parameters = to_string(ALGORITHM);
cku::run_test<Builder>({"DeviceGroupedConvBwdWeight_Xdl_CShuffleV3",
expected_transfer_parameters,
"Filter1x1Stride1Pad0",
"NGCW,GKXC,NGKW",
"GNWC,GKXC,GNWK",
"PassThrough,PassThrough,PassThrough",
"Intrawave",
"v2"});
}
TEST(BwdWeight_1DBf16_CShuffle_V3, Execution)
{
if(!ck_tile::get_device_name().starts_with("gfx9"))
{
// Note: XDL kernel
GTEST_SKIP() << "unsupported architecture";
}
ckt::Args<SIGNATURE> args = {
.lengths =
{
.batch_size = 16,
.groups = 1,
.input_channels = 32,
.output_channels = 48,
.image = {.width = 64},
.filter = {.width = 1},
},
.filter_strides = {.width = 1},
.filter_dilation = {.width = 1},
.input_left_pad = {.width = 0},
.input_right_pad = {.width = 0},
.a_elementwise_op = {},
.b_elementwise_op = {},
.cde_elementwise_op = {},
};
auto inputs = ckt::alloc_inputs(args);
auto outputs = ckt::alloc_outputs(args);
auto reference = ckt::alloc_outputs(args);
ckt::init_inputs(args, inputs.get());
auto conv = Instance{};
EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun());
auto ref_conv = Reference{};
EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun());
EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get()));
}

View File

@@ -4,8 +4,9 @@
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_utils.hpp"
#include "utils/conv_algorithm_type_utils.hpp"
#include "ck_tile/builder/testing/conv_fwd_ck.hpp"
#include "ck_tile/builder/testing/conv_fwd_reference.hpp"
#include "ck_tile/builder/testing/conv/fwd.hpp"
#include "ck_tile/builder/testing/conv/fwd_ck.hpp"
#include "ck_tile/builder/testing/conv/reference.hpp"
#include "ck_tile/host/device_prop.hpp"
#include "testing_utils.hpp"
@@ -14,6 +15,7 @@ namespace ckt = ck_tile::builder::test;
namespace cku = ck_tile::builder::test_utils;
using ck_tile::test::MatchesReference;
using ck_tile::test::SuccessfulRun;
constexpr auto SIGNATURE =
ckt::ConvSignature{.spatial_dim = 2,
@@ -50,10 +52,11 @@ TEST(Fwd2DFp16_CShufV3_GNHWC, Create)
"MNKPadding"});
}
TEST(Fwd2DFp16_CShufV3_GNHWC, EndToEnd)
TEST(Fwd2DFp16_CShufV3_GNHWC, Execution)
{
if(!ck_tile::get_device_name().starts_with("gfx9"))
{
// Note: XDL kernel
GTEST_SKIP() << "unsupported architecture";
}
@@ -91,10 +94,10 @@ TEST(Fwd2DFp16_CShufV3_GNHWC, EndToEnd)
ckt::init_inputs(args, inputs.get());
auto conv = Instance{};
ckt::run(conv, args, inputs.get(), outputs.get());
EXPECT_THAT(ckt::run(conv, args, inputs.get(), outputs.get()), SuccessfulRun());
auto ref_conv = Reference{};
ckt::run(ref_conv, args, inputs.get(), reference.get());
EXPECT_THAT(ckt::run(ref_conv, args, inputs.get(), reference.get()), SuccessfulRun());
EXPECT_THAT(outputs.get(), MatchesReference(args, reference.get()));
}