mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-07-01 12:17:00 +00:00
Improve ckb fwd conv instance tests.
This commit is contained in:
@@ -43,7 +43,8 @@ add_ck_builder_test(test_ckb_build_fwd_instances
|
||||
conv/test_ckb_conv_fwd_2d_fp32.cpp
|
||||
conv/test_ckb_conv_fwd_3d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_fp32.cpp)
|
||||
conv/test_ckb_conv_fwd_3d_fp32.cpp
|
||||
)
|
||||
|
||||
function(add_ck_factory_test test_name)
|
||||
add_ck_builder_test(${test_name} ${ARGN})
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -28,13 +28,12 @@ TEST(FwdConvInstances,
|
||||
.block_gemm = BlockGemmDesc_v2_intrawave};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
const auto& asserts = InstanceNameAsserts{}
|
||||
.StartsWith("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3")
|
||||
.Contains("Filter1x1Stride1Pad0")
|
||||
.Contains("BlkGemmPipelineVersion: v2");
|
||||
|
||||
run_test<Builder>(asserts);
|
||||
run_test<Builder>({
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"256, 256, 256, 32",
|
||||
"Filter1x1Stride1Pad0",
|
||||
"BlkGemmPipelineScheduler: Intrawave",
|
||||
"BlkGemmPipelineVersion: v2"});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -29,12 +29,10 @@ TEST(FwdConvInstances,
|
||||
.loop_scheduler = LoopScheduler::DEFAULT};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
|
||||
const auto& asserts = InstanceNameAsserts{}
|
||||
.StartsWith("DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle")
|
||||
.Contains("Default");
|
||||
|
||||
run_test<Builder>(asserts);
|
||||
run_test<Builder>({
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle",
|
||||
"64, 64, 32, 32",
|
||||
"Default"});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -28,7 +28,10 @@ TEST(FwdConvInstances,
|
||||
.loop_scheduler = LoopScheduler::DEFAULT};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>();
|
||||
run_test<Builder>({
|
||||
"DeviceGroupedConvFwdMultipleD_Wmma_CShuffle",
|
||||
"128, 64, 64, 64",
|
||||
"Default"});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace{
|
||||
|
||||
@@ -27,7 +27,12 @@ TEST(FwdConvInstances,
|
||||
.block_gemm = BlockGemmDesc_v1_intrawave};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>();
|
||||
run_test<Builder>({
|
||||
"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"256, 256, 256, 32",
|
||||
"Default",
|
||||
"BlkGemmPipelineScheduler: Intrawave",
|
||||
"BlkGemmPipelineVersion: v1"});
|
||||
}
|
||||
|
||||
// 2D BF16 NHWGC (channels-last) with Pipeline V5 and FILTER_3x3
|
||||
@@ -52,7 +57,9 @@ TEST(FwdConvInstances,
|
||||
.block_gemm = BlockGemmDesc_v5_intrawave};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>();
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"Filter3x3",
|
||||
"BlkGemmPipelineVersion: v5"});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -26,7 +26,11 @@ TEST(FwdConvInstances,
|
||||
.block_gemm = BlockGemmDesc_v3_intrawave};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>();
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"256, 256, 256, 32",
|
||||
"Filter1x1Pad0",
|
||||
"BlkGemmPipelineScheduler: Intrawave",
|
||||
"BlkGemmPipelineVersion: v3"});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -26,7 +26,11 @@ TEST(FwdConvInstances,
|
||||
.block_gemm = BlockGemmDesc_v4_intrawave};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>();
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"256, 128, 128, 32",
|
||||
"Filter1x1Stride1Pad0",
|
||||
"BlkGemmPipelineScheduler: Intrawave",
|
||||
"BlkGemmPipelineVersion: v4"});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -27,7 +27,11 @@ TEST(FwdConvInstances,
|
||||
.block_gemm = BlockGemmDesc_v3_intrawave};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>();
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"256, 256, 256, 32",
|
||||
"Default",
|
||||
"BlkGemmPipelineScheduler: Intrawave",
|
||||
"BlkGemmPipelineVersion: v3"});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -27,7 +27,11 @@ TEST(FwdConvInstances,
|
||||
.block_gemm = BlockGemmDesc_v4_intrawave};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>();
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"256, 128, 128, 32",
|
||||
"Filter1x1Pad0",
|
||||
"BlkGemmPipelineScheduler: Intrawave",
|
||||
"BlkGemmPipelineVersion: v4"});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#include "utils/ckb_conv_test_configs.hpp"
|
||||
#include "utils/ckb_conv_test_common.hpp"
|
||||
#include "utils/ckb_conv_test_utils.hpp"
|
||||
|
||||
namespace {
|
||||
|
||||
@@ -27,7 +27,11 @@ TEST(FwdConvInstances,
|
||||
.block_gemm = BlockGemmDesc_v1_intrawave};
|
||||
|
||||
using Builder = ConvBuilder<FwdConvSignature, FwdConvAlgorithm>;
|
||||
run_test<Builder>();
|
||||
run_test<Builder>({"DeviceGroupedConvFwdMultipleABD_Xdl_CShuffle_V3",
|
||||
"256, 256, 256, 32",
|
||||
"Filter1x1Pad0",
|
||||
"BlkGemmPipelineScheduler: Intrawave",
|
||||
"BlkGemmPipelineVersion: v1"});
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
#include "impl/conv_algorithm_types.hpp"
|
||||
#include "impl/conv_signature_types.hpp"
|
||||
#include "ck_tile/builder/conv_builder.hpp"
|
||||
|
||||
namespace ck_tile::builder::test_utils {
|
||||
|
||||
using namespace ck_tile::builder;
|
||||
using namespace test;
|
||||
|
||||
class InstanceNameAsserts
|
||||
{
|
||||
public:
|
||||
InstanceNameAsserts& StartsWith(const char* prefix)
|
||||
{
|
||||
prefixes_.push_back(std::string(prefix));
|
||||
return *this;
|
||||
}
|
||||
|
||||
InstanceNameAsserts& Contains(const char* substring)
|
||||
{
|
||||
substrings_.push_back(std::string(substring));
|
||||
return *this;
|
||||
}
|
||||
|
||||
void Check(const std::string& kernel_string) const
|
||||
{
|
||||
for (const auto& prefix : prefixes_)
|
||||
{
|
||||
EXPECT_THAT(kernel_string, ::testing::StartsWith(prefix));
|
||||
}
|
||||
for (const auto& substr : substrings_)
|
||||
{
|
||||
EXPECT_THAT(kernel_string, ::testing::HasSubstr(substr));
|
||||
}
|
||||
}
|
||||
private:
|
||||
std::vector<std::string> prefixes_;
|
||||
std::vector<std::string> substrings_;
|
||||
};
|
||||
|
||||
// Common test implementation
|
||||
template <typename Builder>
|
||||
constexpr void run_test(const InstanceNameAsserts& asserts)
|
||||
{
|
||||
auto instance = typename Builder::Instance{};
|
||||
|
||||
const auto kernel_string = instance.GetTypeString();
|
||||
std::cout << "Generated kernel: " << kernel_string << std::endl;
|
||||
EXPECT_GT(kernel_string.size(), 0);
|
||||
|
||||
const auto invoker_ptr = instance.MakeInvokerPointer();
|
||||
EXPECT_NE(invoker_ptr, nullptr);
|
||||
|
||||
asserts.Check(kernel_string);
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test_utils
|
||||
37
experimental/builder/test/utils/ckb_conv_test_utils.hpp
Normal file
37
experimental/builder/test/utils/ckb_conv_test_utils.hpp
Normal file
@@ -0,0 +1,37 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
#include "impl/conv_algorithm_types.hpp"
|
||||
#include "impl/conv_signature_types.hpp"
|
||||
#include "ck_tile/builder/conv_builder.hpp"
|
||||
|
||||
namespace ck_tile::builder::test_utils {
|
||||
|
||||
using namespace ck_tile::builder;
|
||||
using namespace test;
|
||||
|
||||
|
||||
// Common test implementation
|
||||
template <typename Builder>
|
||||
constexpr void run_test(const std::vector<std::string>& kernel_instance_components)
|
||||
{
|
||||
auto instance = typename Builder::Instance{};
|
||||
|
||||
const auto kernel_string = instance.GetTypeString();
|
||||
std::cout << "Generated kernel: " << kernel_string << std::endl;
|
||||
EXPECT_GT(kernel_string.size(), 0);
|
||||
|
||||
const auto invoker_ptr = instance.MakeInvokerPointer();
|
||||
EXPECT_NE(invoker_ptr, nullptr);
|
||||
|
||||
for (const auto& component : kernel_instance_components)
|
||||
{
|
||||
EXPECT_THAT(kernel_string, ::testing::HasSubstr(component));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace ck_tile::builder::test_utils
|
||||
Reference in New Issue
Block a user