Improve ckb fwd conv instance tests.

This commit is contained in:
Ville Pietilä
2025-11-04 08:58:25 +00:00
parent 1893079fdc
commit 0ac48abe61
12 changed files with 96 additions and 95 deletions

View File

@@ -43,7 +43,8 @@ add_ck_builder_test(test_ckb_build_fwd_instances
conv/test_ckb_conv_fwd_2d_fp32.cpp
conv/test_ckb_conv_fwd_3d_bf16.cpp
conv/test_ckb_conv_fwd_3d_fp16.cpp
conv/test_ckb_conv_fwd_3d_fp32.cpp)
conv/test_ckb_conv_fwd_3d_fp32.cpp
)
function(add_ck_factory_test test_name)
add_ck_builder_test(${test_name} ${ARGN})

View File

@@ -1,5 +1,5 @@
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_common.hpp"
#include "utils/ckb_conv_test_utils.hpp"
namespace {
@@ -28,13 +28,12 @@ TEST(FwdConvInstances,
.block_gemm = BlockGemmDesc_v2_intrawave};
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
const auto& asserts = InstanceNameAsserts{}
.StartsWith("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3")
.Contains("Filter1x1Stride1Pad0")
.Contains("BlkGemmPipelineVersion: v2");
run_test<Builder>(asserts);
run_test<Builder>({
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
"256, 256, 256, 32",
"Filter1x1Stride1Pad0",
"BlkGemmPipelineScheduler: Intrawave",
"BlkGemmPipelineVersion: v2"});
}
} // namespace

View File

@@ -1,5 +1,5 @@
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_common.hpp"
#include "utils/ckb_conv_test_utils.hpp"
namespace {
@@ -29,12 +29,10 @@ TEST(FwdConvInstances,
.loop_scheduler = LoopScheduler::DEFAULT};
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
const auto& asserts = InstanceNameAsserts{}
.StartsWith("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle")
.Contains("Default");
run_test<Builder>(asserts);
run_test<Builder>({
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle",
"64, 64, 32, 32",
"Default"});
}
} // namespace

View File

@@ -1,5 +1,5 @@
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_common.hpp"
#include "utils/ckb_conv_test_utils.hpp"
namespace {
@@ -28,7 +28,10 @@ TEST(FwdConvInstances,
.loop_scheduler = LoopScheduler::DEFAULT};
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>();
run_test<Builder>({
"DeviceGroupedConvFwdMultipleD_Wmma_CShuffle",
"128, 64, 64, 64",
"Default"});
}
} // namespace

View File

@@ -1,5 +1,5 @@
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_common.hpp"
#include "utils/ckb_conv_test_utils.hpp"
namespace{
@@ -27,7 +27,12 @@ TEST(FwdConvInstances,
.block_gemm = BlockGemmDesc_v1_intrawave};
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>();
run_test<Builder>({
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
"256, 256, 256, 32",
"Default",
"BlkGemmPipelineScheduler: Intrawave",
"BlkGemmPipelineVersion: v1"});
}
// 2D BF16 NHWGC (channels-last) with Pipeline V5 and FILTER_3x3
@@ -52,7 +57,9 @@ TEST(FwdConvInstances,
.block_gemm = BlockGemmDesc_v5_intrawave};
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>();
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
"Filter3x3",
"BlkGemmPipelineVersion: v5"});
}
} // namespace

View File

@@ -1,5 +1,5 @@
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_common.hpp"
#include "utils/ckb_conv_test_utils.hpp"
namespace {
@@ -26,7 +26,11 @@ TEST(FwdConvInstances,
.block_gemm = BlockGemmDesc_v3_intrawave};
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>();
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
"256, 256, 256, 32",
"Filter1x1Pad0",
"BlkGemmPipelineScheduler: Intrawave",
"BlkGemmPipelineVersion: v3"});
}
} // namespace

View File

@@ -1,5 +1,5 @@
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_common.hpp"
#include "utils/ckb_conv_test_utils.hpp"
namespace {
@@ -26,7 +26,11 @@ TEST(FwdConvInstances,
.block_gemm = BlockGemmDesc_v4_intrawave};
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>();
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
"256, 128, 128, 32",
"Filter1x1Stride1Pad0",
"BlkGemmPipelineScheduler: Intrawave",
"BlkGemmPipelineVersion: v4"});
}
} // namespace

View File

@@ -1,5 +1,5 @@
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_common.hpp"
#include "utils/ckb_conv_test_utils.hpp"
namespace {
@@ -27,7 +27,11 @@ TEST(FwdConvInstances,
.block_gemm = BlockGemmDesc_v3_intrawave};
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>();
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
"256, 256, 256, 32",
"Default",
"BlkGemmPipelineScheduler: Intrawave",
"BlkGemmPipelineVersion: v3"});
}
} // namespace

View File

@@ -1,5 +1,5 @@
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_common.hpp"
#include "utils/ckb_conv_test_utils.hpp"
namespace {
@@ -27,7 +27,11 @@ TEST(FwdConvInstances,
.block_gemm = BlockGemmDesc_v4_intrawave};
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>();
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
"256, 128, 128, 32",
"Filter1x1Pad0",
"BlkGemmPipelineScheduler: Intrawave",
"BlkGemmPipelineVersion: v4"});
}
} // namespace

View File

@@ -1,5 +1,5 @@
#include "utils/ckb_conv_test_configs.hpp"
#include "utils/ckb_conv_test_common.hpp"
#include "utils/ckb_conv_test_utils.hpp"
namespace {
@@ -27,7 +27,11 @@ TEST(FwdConvInstances,
.block_gemm = BlockGemmDesc_v1_intrawave};
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
run_test<Builder>();
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
"256, 256, 256, 32",
"Filter1x1Pad0",
"BlkGemmPipelineScheduler: Intrawave",
"BlkGemmPipelineVersion: v1"});
}
} // namespace

View File

@@ -1,64 +0,0 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include "impl/conv_algorithm_types.hpp"
#include "impl/conv_signature_types.hpp"
#include "ck_tile/builder/conv_builder.hpp"
namespace ck_tile::builder::test_utils {
using namespace ck_tile::builder;
using namespace test;
class InstanceNameAsserts
{
public:
InstanceNameAsserts& StartsWith(const char* prefix)
{
prefixes_.push_back(std::string(prefix));
return *this;
}
InstanceNameAsserts& Contains(const char* substring)
{
substrings_.push_back(std::string(substring));
return *this;
}
void Check(const std::string& kernel_string) const
{
for (const auto& prefix : prefixes_)
{
EXPECT_THAT(kernel_string, ::testing::StartsWith(prefix));
}
for (const auto& substr : substrings_)
{
EXPECT_THAT(kernel_string, ::testing::HasSubstr(substr));
}
}
private:
std::vector<std::string> prefixes_;
std::vector<std::string> substrings_;
};
// Common test implementation
template <typename Builder>
constexpr void run_test(const InstanceNameAsserts& asserts)
{
auto instance = typename Builder::Instance{};
const auto kernel_string = instance.GetTypeString();
std::cout << "Generated kernel: " << kernel_string << std::endl;
EXPECT_GT(kernel_string.size(), 0);
const auto invoker_ptr = instance.MakeInvokerPointer();
EXPECT_NE(invoker_ptr, nullptr);
asserts.Check(kernel_string);
}
} // namespace ck_tile::builder::test_utils

View File

@@ -0,0 +1,37 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include "impl/conv_algorithm_types.hpp"
#include "impl/conv_signature_types.hpp"
#include "ck_tile/builder/conv_builder.hpp"
namespace ck_tile::builder::test_utils {
using namespace ck_tile::builder;
using namespace test;
// Common test implementation
template <typename Builder>
constexpr void run_test(const std::vector<std::string>& kernel_instance_components)
{
auto instance = typename Builder::Instance{};
const auto kernel_string = instance.GetTypeString();
std::cout << "Generated kernel: " << kernel_string << std::endl;
EXPECT_GT(kernel_string.size(), 0);
const auto invoker_ptr = instance.MakeInvokerPointer();
EXPECT_NE(invoker_ptr, nullptr);
for (const auto& component : kernel_instance_components)
{
EXPECT_THAT(kernel_string, ::testing::HasSubstr(component));
}
}
} // namespace ck_tile::builder::test_utils