mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-04-19 22:39:03 +00:00
[CK_BUILDER] Test and fix instance traits utils. (#3096)
* Refactor instance_traits_util and add unit tests tests * Address reviewer comments. Just adds some TODOs to indicate deprecated layouts in our reflection. Our strategy is to leave the reflection code broad (covering deprecated features), but keep the builder concepts narrow. Once we've removed deprecated features from all instances, we can remove them from reflection. Also add a comment to the cmake to explain the unit test target test_conv_builder. * Addressed more reviewer comments. * Remove duplicate PassThrough::name Accidentally added this field to the end of the struct, too. The `name` field should be a the start of the struct for consistency.
This commit is contained in:
@@ -23,19 +23,12 @@
|
||||
|
||||
namespace ck_tile::reflect::detail {
|
||||
|
||||
// Metaprogramming helper to convert ck::Sequence to constexpr std::array
|
||||
template <typename Seq>
|
||||
struct SequenceToArray;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
struct SequenceToArray<ck::Sequence<Is...>>
|
||||
{
|
||||
static constexpr std::array<int, sizeof...(Is)> value = {static_cast<int>(Is)...};
|
||||
};
|
||||
|
||||
// Convert data types to string names
|
||||
// Implementation detail for type name mapping
|
||||
// This is the single source of truth for supported data types that
|
||||
// returns an empty string to indicate an unsupported type.
|
||||
namespace impl {
|
||||
template <typename T>
|
||||
consteval std::string_view type_name()
|
||||
consteval std::string_view type_name_impl()
|
||||
{
|
||||
if constexpr(std::is_same_v<T, ck::half_t>)
|
||||
return "fp16";
|
||||
@@ -56,22 +49,38 @@ consteval std::string_view type_name()
|
||||
else if constexpr(std::is_same_v<T, ck::bf8_t>)
|
||||
return "bf8";
|
||||
else
|
||||
static_assert(false, "unknown_type");
|
||||
return std::string_view{}; // Return empty for supported types
|
||||
}
|
||||
} // namespace impl
|
||||
|
||||
// Convert data types to string names
|
||||
// Fails at compile time for unsupported types
|
||||
template <typename T>
|
||||
consteval std::string_view type_name()
|
||||
{
|
||||
constexpr auto name = impl::type_name_impl<T>();
|
||||
static_assert(!name.empty(), "Unsupported data type");
|
||||
return name;
|
||||
}
|
||||
|
||||
// Convert layout types to string names
|
||||
// Concept that checks if a type is a valid data type
|
||||
// Uses the impl directly to avoid triggering static_assert during concept evaluation
|
||||
template <typename T>
|
||||
concept IsDataType = !impl::type_name_impl<T>().empty();
|
||||
|
||||
// Concept that checks valid layout types
|
||||
template <typename T>
|
||||
concept IsLayoutType = (std::is_base_of_v<ck_tile::tensor_layout::BaseTensorLayout, T> ||
|
||||
std::is_base_of_v<ck::tensor_layout::BaseTensorLayout, T>) &&
|
||||
requires {
|
||||
{ T::name } -> std::convertible_to<std::string_view>;
|
||||
};
|
||||
|
||||
// Convert layout types to string names
|
||||
template <IsLayoutType T>
|
||||
constexpr std::string_view layout_name()
|
||||
{
|
||||
if constexpr((std::is_base_of_v<ck_tile::tensor_layout::BaseTensorLayout, T> ||
|
||||
std::is_base_of_v<ck::tensor_layout::BaseTensorLayout, T>) &&
|
||||
requires {
|
||||
{ T::name } -> std::convertible_to<std::string_view>;
|
||||
})
|
||||
return T::name;
|
||||
else
|
||||
static_assert(false,
|
||||
"Layout type must derive from BaseTensorLayout and have name attribute");
|
||||
return T::name;
|
||||
}
|
||||
|
||||
// Convert element-wise operation types to string names
|
||||
@@ -90,64 +99,64 @@ constexpr std::string_view elementwise_op_name()
|
||||
constexpr std::string_view
|
||||
conv_fwd_spec_name(ck::tensor_operation::device::ConvolutionForwardSpecialization spec)
|
||||
{
|
||||
using ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case ConvolutionForwardSpecialization::Default: return "Default";
|
||||
case ConvolutionForwardSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
|
||||
case ConvolutionForwardSpecialization::Filter1x1Pad0: return "Filter1x1Pad0";
|
||||
case ConvolutionForwardSpecialization::Filter3x3: return "Filter3x3";
|
||||
case ConvolutionForwardSpecialization::OddC: return "OddC";
|
||||
case Default: return "Default";
|
||||
case Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
|
||||
case Filter1x1Pad0: return "Filter1x1Pad0";
|
||||
case Filter3x3: return "Filter3x3";
|
||||
case OddC: return "OddC";
|
||||
}
|
||||
}
|
||||
|
||||
// Convert GemmSpecialization enum to string
|
||||
constexpr std::string_view gemm_spec_name(ck::tensor_operation::device::GemmSpecialization spec)
|
||||
{
|
||||
using ck::tensor_operation::device::GemmSpecialization;
|
||||
using enum ck::tensor_operation::device::GemmSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case GemmSpecialization::Default: return "Default";
|
||||
case GemmSpecialization::MPadding: return "MPadding";
|
||||
case GemmSpecialization::NPadding: return "NPadding";
|
||||
case GemmSpecialization::KPadding: return "KPadding";
|
||||
case GemmSpecialization::MNPadding: return "MNPadding";
|
||||
case GemmSpecialization::MKPadding: return "MKPadding";
|
||||
case GemmSpecialization::NKPadding: return "NKPadding";
|
||||
case GemmSpecialization::MNKPadding: return "MNKPadding";
|
||||
case GemmSpecialization::OPadding: return "OPadding";
|
||||
case GemmSpecialization::MOPadding: return "MOPadding";
|
||||
case GemmSpecialization::NOPadding: return "NOPadding";
|
||||
case GemmSpecialization::KOPadding: return "KOPadding";
|
||||
case GemmSpecialization::MNOPadding: return "MNOPadding";
|
||||
case GemmSpecialization::MKOPadding: return "MKOPadding";
|
||||
case GemmSpecialization::NKOPadding: return "NKOPadding";
|
||||
case GemmSpecialization::MNKOPadding: return "MNKOPadding";
|
||||
case Default: return "Default";
|
||||
case MPadding: return "MPadding";
|
||||
case NPadding: return "NPadding";
|
||||
case KPadding: return "KPadding";
|
||||
case MNPadding: return "MNPadding";
|
||||
case MKPadding: return "MKPadding";
|
||||
case NKPadding: return "NKPadding";
|
||||
case MNKPadding: return "MNKPadding";
|
||||
case OPadding: return "OPadding";
|
||||
case MOPadding: return "MOPadding";
|
||||
case NOPadding: return "NOPadding";
|
||||
case KOPadding: return "KOPadding";
|
||||
case MNOPadding: return "MNOPadding";
|
||||
case MKOPadding: return "MKOPadding";
|
||||
case NKOPadding: return "NKOPadding";
|
||||
case MNKOPadding: return "MNKOPadding";
|
||||
}
|
||||
}
|
||||
|
||||
// Convert BlockGemmPipelineScheduler enum to string
|
||||
constexpr std::string_view pipeline_scheduler_name(ck::BlockGemmPipelineScheduler sched)
|
||||
{
|
||||
using ck::BlockGemmPipelineScheduler;
|
||||
using enum ck::BlockGemmPipelineScheduler;
|
||||
switch(sched)
|
||||
{
|
||||
case BlockGemmPipelineScheduler::Intrawave: return "Intrawave";
|
||||
case BlockGemmPipelineScheduler::Interwave: return "Interwave";
|
||||
case Intrawave: return "Intrawave";
|
||||
case Interwave: return "Interwave";
|
||||
}
|
||||
}
|
||||
|
||||
// Convert BlockGemmPipelineVersion enum to string
|
||||
constexpr std::string_view pipeline_version_name(ck::BlockGemmPipelineVersion ver)
|
||||
{
|
||||
using ck::BlockGemmPipelineVersion;
|
||||
using enum ck::BlockGemmPipelineVersion;
|
||||
switch(ver)
|
||||
{
|
||||
case BlockGemmPipelineVersion::v1: return "v1";
|
||||
case BlockGemmPipelineVersion::v2: return "v2";
|
||||
case BlockGemmPipelineVersion::v3: return "v3";
|
||||
case BlockGemmPipelineVersion::v4: return "v4";
|
||||
case BlockGemmPipelineVersion::v5: return "v5";
|
||||
case v1: return "v1";
|
||||
case v2: return "v2";
|
||||
case v3: return "v3";
|
||||
case v4: return "v4";
|
||||
case v5: return "v5";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -167,12 +176,138 @@ inline std::string array_to_string(const std::array<T, N>& arr)
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
// Handle ck::Tuple (empty tuple for DsLayout/DsDataType)
|
||||
template <typename T>
|
||||
constexpr std::string_view tuple_name()
|
||||
// Metaprogramming helper to convert ck::Sequence to constexpr std::array
|
||||
template <typename Seq>
|
||||
struct SequenceToArray;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
struct SequenceToArray<ck::Sequence<Is...>>
|
||||
{
|
||||
// For now, just check if it's an empty tuple
|
||||
return "EmptyTuple";
|
||||
static constexpr std::array<int, sizeof...(Is)> value = {static_cast<int>(Is)...};
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
// Generic helper to build list-like strings (Tuple, Seq, etc.)
|
||||
//
|
||||
// Example output: "Seq(1,2,3)"
|
||||
//
|
||||
// prefix: The list-like container name (e.g. "Tuple" or "Seq")
|
||||
// converter_fn: A callable that converts each element to a string representation
|
||||
// For types: converter_fn should be a template lambda like []<typename U>() { return
|
||||
// type_name<U>(); } For values: converter_fn should be a regular lambda like [](auto value) {
|
||||
// return std::to_string(value); }
|
||||
template <typename ConverterFn, typename... Elements>
|
||||
constexpr std::string build_list_string(std::string_view prefix, const ConverterFn& converter_fn)
|
||||
{
|
||||
if constexpr(sizeof...(Elements) == 0)
|
||||
{
|
||||
return std::string(prefix) + "()";
|
||||
}
|
||||
else
|
||||
{
|
||||
std::string result = std::string(prefix) + "(";
|
||||
std::size_t index = 0;
|
||||
((result +=
|
||||
(index++ > 0 ? "," : "") + std::string(converter_fn.template operator()<Elements>())),
|
||||
...);
|
||||
result += ")";
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
// Overload for value-based lists (sequences)
|
||||
template <typename ConverterFn, auto... Values>
|
||||
constexpr std::string build_list_string_values(std::string_view prefix,
|
||||
const ConverterFn& converter_fn)
|
||||
{
|
||||
if constexpr(sizeof...(Values) == 0)
|
||||
{
|
||||
return std::string(prefix) + "()";
|
||||
}
|
||||
else
|
||||
{
|
||||
std::string result = std::string(prefix) + "(";
|
||||
std::size_t index = 0;
|
||||
((result += (index++ > 0 ? "," : "") + converter_fn(Values)), ...);
|
||||
result += ")";
|
||||
return result;
|
||||
}
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
// Convert ck::Sequence to string representation
|
||||
// Converts a ck::Sequence type to a string in the format "Seq(v1,v2,...,vn)"
|
||||
// where each value is converted using std::to_string.
|
||||
//
|
||||
// Template parameter:
|
||||
// T: Must be a ck::Sequence<...> type
|
||||
//
|
||||
// Constraints:
|
||||
// - Sequence elements must support std::to_string (typically ck::index_t)
|
||||
//
|
||||
// Examples:
|
||||
// sequence_name<ck::Sequence<>>() returns "Seq()"
|
||||
// sequence_name<ck::Sequence<42>>() returns "Seq(42)"
|
||||
// sequence_name<ck::Sequence<1,2,3>>() returns "Seq(1,2,3)"
|
||||
// sequence_name<ck::Sequence<256,128,64>>() returns "Seq(256,128,64)"
|
||||
template <typename T>
|
||||
requires requires { []<ck::index_t... Is>(ck::Sequence<Is...>*) {}(static_cast<T*>(nullptr)); }
|
||||
constexpr std::string sequence_name()
|
||||
{
|
||||
return []<ck::index_t... Is>(ck::Sequence<Is...>*) constexpr {
|
||||
auto to_string_fn = [](auto value) { return std::to_string(value); };
|
||||
return detail::build_list_string_values<decltype(to_string_fn), Is...>("Seq", to_string_fn);
|
||||
}(static_cast<T*>(nullptr));
|
||||
}
|
||||
|
||||
// Convert ck::Tuple to string representation
|
||||
// Converts a ck::Tuple type to a string in the format "Tuple(e1,e2,...,en)"
|
||||
// where each element is converted based on its type (layout names or data type names).
|
||||
//
|
||||
// Template parameter:
|
||||
// T: Must be a ck::Tuple<...> type
|
||||
//
|
||||
// Constraints:
|
||||
// - Empty tuples are supported and return "EmptyTuple"
|
||||
// - All tuple elements must be homogeneous: either all layouts (IsLayoutType) or all data types
|
||||
// (IsDataType)
|
||||
// - Mixed layouts and data types in the same tuple will cause a compile-time error
|
||||
//
|
||||
// Examples:
|
||||
// tuple_name<ck::Tuple<>>() returns "EmptyTuple"
|
||||
// tuple_name<ck::Tuple<ck::tensor_layout::gemm::RowMajor>>() returns "Tuple(RowMajor)"
|
||||
// tuple_name<ck::Tuple<NCHW,NHWC>>() returns "Tuple(NCHW,NHWC)"
|
||||
// tuple_name<ck::Tuple<ck::half_t>>() returns "Tuple(fp16)"
|
||||
// tuple_name<ck::Tuple<ck::half_t,float,double>>() returns "Tuple(fp16,fp32,fp64)"
|
||||
template <typename T>
|
||||
requires requires { []<typename... Ts>(ck::Tuple<Ts...>*) {}(static_cast<T*>(nullptr)); }
|
||||
constexpr std::string tuple_name()
|
||||
{
|
||||
return []<typename... Ts>(ck::Tuple<Ts...>*) constexpr {
|
||||
if constexpr(sizeof...(Ts) == 0)
|
||||
{
|
||||
return std::string("EmptyTuple");
|
||||
}
|
||||
else if constexpr((IsLayoutType<Ts> && ...))
|
||||
{
|
||||
// Lambda wrapper for layout_name
|
||||
auto layout_name_fn = []<typename U>() { return layout_name<U>(); };
|
||||
return detail::build_list_string<decltype(layout_name_fn), Ts...>("Tuple",
|
||||
layout_name_fn);
|
||||
}
|
||||
else if constexpr((IsDataType<Ts> && ...))
|
||||
{
|
||||
// Lambda wrapper for type_name
|
||||
auto type_name_fn = []<typename U>() { return type_name<U>(); };
|
||||
return detail::build_list_string<decltype(type_name_fn), Ts...>("Tuple", type_name_fn);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert((IsLayoutType<Ts> && ...) || (IsDataType<Ts> && ...),
|
||||
"Tuple elements must be all layouts or all data types, not mixed");
|
||||
return std::string{}; // unreachable
|
||||
}
|
||||
}(static_cast<T*>(nullptr));
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect::detail
|
||||
|
||||
@@ -16,20 +16,13 @@ function(add_ck_builder_test test_name)
|
||||
target_link_libraries(${test_name} PRIVATE GTest::gtest_main GTest::gmock)
|
||||
endfunction()
|
||||
|
||||
# The test_conv_builder target has all the unit tests (each test should run < 10 ms)
|
||||
add_ck_builder_test(test_conv_builder
|
||||
test_conv_builder.cpp
|
||||
test_instance_traits.cpp
|
||||
test_instance_traits_util.cpp
|
||||
testing_utils.cpp)
|
||||
|
||||
# Testing the virtual GetInstanceString methods requires kernel compilation.
|
||||
add_ck_builder_test(test_get_instance_string
|
||||
test_get_instance_string.cpp)
|
||||
|
||||
add_ck_builder_test(test_inline_diff test_inline_diff.cpp testing_utils.cpp)
|
||||
|
||||
add_ck_builder_test(test_ckb_build_fwd_instances
|
||||
conv/test_ckb_conv_fwd_2d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp32.cpp
|
||||
conv/test_ckb_conv_fwd_3d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_fp32.cpp)
|
||||
263
experimental/builder/test/test_instance_traits_util.cpp
Normal file
263
experimental/builder/test/test_instance_traits_util.cpp
Normal file
@@ -0,0 +1,263 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
#include <ck_tile/builder/reflect/instance_traits_util.hpp>
|
||||
#include <ck/utility/data_type.hpp>
|
||||
#include <ck/utility/sequence.hpp>
|
||||
#include <ck/utility/blkgemmpipe_scheduler.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
|
||||
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
|
||||
|
||||
namespace ck_tile::reflect::detail {
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAre;
|
||||
using ::testing::IsEmpty;
|
||||
|
||||
TEST(InstanceTraitsUtil, SequenceToArrayReturnsEmptyArrayForEmptySequence)
|
||||
{
|
||||
EXPECT_THAT(SequenceToArray<ck::Sequence<>>::value, IsEmpty());
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, SequenceToArrayReturnsArrayWithSingleElement)
|
||||
{
|
||||
EXPECT_THAT(SequenceToArray<ck::Sequence<42>>::value, ElementsAre(42));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, SequenceToArrayReturnsArrayWithMultipleElements)
|
||||
{
|
||||
EXPECT_THAT((SequenceToArray<ck::Sequence<1, 2, 3, 4, 5>>::value), ElementsAre(1, 2, 3, 4, 5));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, TypeNameReturnsCorrectStrings)
|
||||
{
|
||||
EXPECT_THAT((std::vector<std::string_view>{type_name<ck::half_t>(),
|
||||
type_name<float>(),
|
||||
type_name<double>(),
|
||||
type_name<int8_t>(),
|
||||
type_name<int32_t>(),
|
||||
type_name<ck::bhalf_t>(),
|
||||
type_name<ck::f8_t>(),
|
||||
type_name<ck::bf8_t>()}),
|
||||
ElementsAre("fp16", "fp32", "fp64", "s8", "s32", "bf16", "fp8", "bf8"));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, LayoutNameReturnsCorrectStringsForGemmLayouts)
|
||||
{
|
||||
namespace gemm = ck::tensor_layout::gemm;
|
||||
EXPECT_THAT((std::vector<std::string_view>{layout_name<gemm::RowMajor>(),
|
||||
layout_name<gemm::ColumnMajor>(),
|
||||
layout_name<gemm::MFMA>()}),
|
||||
ElementsAre("RowMajor", "ColumnMajor", "MFMA"));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, LayoutNameReturnsCorrectStringsForConvLayouts)
|
||||
{
|
||||
namespace conv = ck::tensor_layout::convolution;
|
||||
EXPECT_THAT((std::vector<std::string_view>{
|
||||
// Input tensor layouts
|
||||
// TODO(deprecated): Remove non-grouped layouts once instances are removed.
|
||||
layout_name<conv::NCHW>(),
|
||||
layout_name<conv::NHWC>(),
|
||||
layout_name<conv::NCDHW>(),
|
||||
layout_name<conv::NDHWC>(),
|
||||
// Grouped input layouts
|
||||
layout_name<conv::GNCHW>(),
|
||||
layout_name<conv::GNHWC>(),
|
||||
// Weight tensor layouts
|
||||
layout_name<conv::KCYX>(),
|
||||
layout_name<conv::KYXC>(),
|
||||
layout_name<conv::GKCYX>(),
|
||||
layout_name<conv::GKYXC>(),
|
||||
// Output tensor layouts
|
||||
layout_name<conv::NKHW>(),
|
||||
layout_name<conv::NHWK>(),
|
||||
layout_name<conv::GNKHW>(),
|
||||
layout_name<conv::GNHWK>(),
|
||||
// Strided layouts
|
||||
// TODO(deprecated): Remove strided layouts once instances are removed.
|
||||
layout_name<conv::G_NHW_C>(),
|
||||
layout_name<conv::G_K_YX_C>(),
|
||||
layout_name<conv::G_NHW_K>(),
|
||||
// Bias layouts
|
||||
layout_name<conv::G_C>(),
|
||||
layout_name<conv::G_K>()}),
|
||||
ElementsAre("NCHW",
|
||||
"NHWC",
|
||||
"NCDHW",
|
||||
"NDHWC",
|
||||
"GNCHW",
|
||||
"GNHWC",
|
||||
"KCYX",
|
||||
"KYXC",
|
||||
"GKCYX",
|
||||
"GKYXC",
|
||||
"NKHW",
|
||||
"NHWK",
|
||||
"GNKHW",
|
||||
"GNHWK",
|
||||
"G_NHW_C",
|
||||
"G_K_YX_C",
|
||||
"G_NHW_K",
|
||||
"G_C",
|
||||
"G_K"));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, ElementwiseOpNameReturnsCorrectStrings)
|
||||
{
|
||||
namespace element_wise = ck::tensor_operation::element_wise;
|
||||
EXPECT_THAT((std::vector<std::string_view>{
|
||||
elementwise_op_name<element_wise::PassThrough>(),
|
||||
elementwise_op_name<element_wise::Scale>(),
|
||||
elementwise_op_name<element_wise::Bilinear>(),
|
||||
elementwise_op_name<element_wise::Add>(),
|
||||
elementwise_op_name<element_wise::AddRelu>(),
|
||||
elementwise_op_name<element_wise::Relu>(),
|
||||
elementwise_op_name<element_wise::BiasNormalizeInInferClamp>(),
|
||||
elementwise_op_name<element_wise::Clamp>(),
|
||||
elementwise_op_name<element_wise::AddClamp>()}),
|
||||
ElementsAre("PassThrough",
|
||||
"Scale",
|
||||
"Bilinear",
|
||||
"Add",
|
||||
"AddRelu",
|
||||
"Relu",
|
||||
"BiasNormalizeInInferClamp",
|
||||
"Clamp",
|
||||
"AddClamp"));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, ConvFwdSpecNameReturnsCorrectStrings)
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
EXPECT_THAT(
|
||||
(std::vector<std::string_view>{conv_fwd_spec_name(Default),
|
||||
conv_fwd_spec_name(Filter1x1Stride1Pad0),
|
||||
conv_fwd_spec_name(Filter1x1Pad0),
|
||||
conv_fwd_spec_name(Filter3x3),
|
||||
conv_fwd_spec_name(OddC)}),
|
||||
ElementsAre("Default", "Filter1x1Stride1Pad0", "Filter1x1Pad0", "Filter3x3", "OddC"));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, GemmSpecNameReturnsCorrectStrings)
|
||||
{
|
||||
using enum ck::tensor_operation::device::GemmSpecialization;
|
||||
EXPECT_THAT((std::vector<std::string_view>{gemm_spec_name(Default),
|
||||
gemm_spec_name(MPadding),
|
||||
gemm_spec_name(NPadding),
|
||||
gemm_spec_name(KPadding),
|
||||
gemm_spec_name(MNPadding),
|
||||
gemm_spec_name(MKPadding),
|
||||
gemm_spec_name(NKPadding),
|
||||
gemm_spec_name(MNKPadding),
|
||||
gemm_spec_name(OPadding),
|
||||
gemm_spec_name(MOPadding),
|
||||
gemm_spec_name(NOPadding),
|
||||
gemm_spec_name(KOPadding),
|
||||
gemm_spec_name(MNOPadding),
|
||||
gemm_spec_name(MKOPadding),
|
||||
gemm_spec_name(NKOPadding),
|
||||
gemm_spec_name(MNKOPadding)}),
|
||||
ElementsAre("Default",
|
||||
"MPadding",
|
||||
"NPadding",
|
||||
"KPadding",
|
||||
"MNPadding",
|
||||
"MKPadding",
|
||||
"NKPadding",
|
||||
"MNKPadding",
|
||||
"OPadding",
|
||||
"MOPadding",
|
||||
"NOPadding",
|
||||
"KOPadding",
|
||||
"MNOPadding",
|
||||
"MKOPadding",
|
||||
"NKOPadding",
|
||||
"MNKOPadding"));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, PipelineSchedulerNameReturnsCorrectStrings)
|
||||
{
|
||||
using enum ck::BlockGemmPipelineScheduler;
|
||||
EXPECT_THAT((std::vector<std::string_view>{pipeline_scheduler_name(Intrawave),
|
||||
pipeline_scheduler_name(Interwave)}),
|
||||
ElementsAre("Intrawave", "Interwave"));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, PipelineVersionNameReturnsCorrectStrings)
|
||||
{
|
||||
using enum ck::BlockGemmPipelineVersion;
|
||||
EXPECT_THAT((std::vector<std::string_view>{pipeline_version_name(v1),
|
||||
pipeline_version_name(v2),
|
||||
pipeline_version_name(v3),
|
||||
pipeline_version_name(v4),
|
||||
pipeline_version_name(v5)}),
|
||||
ElementsAre("v1", "v2", "v3", "v4", "v5"));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, TupleNameReturnsEmptyTupleForEmptyTuple)
|
||||
{
|
||||
EXPECT_EQ(tuple_name<ck::Tuple<>>(), "EmptyTuple");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForSingleLayout)
|
||||
{
|
||||
EXPECT_EQ(tuple_name<ck::Tuple<ck::tensor_layout::convolution::NCHW>>(), "Tuple(NCHW)");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForTwoLayouts)
|
||||
{
|
||||
EXPECT_EQ((tuple_name<ck::Tuple<ck::tensor_layout::convolution::NCHW,
|
||||
ck::tensor_layout::convolution::NHWC>>()),
|
||||
"Tuple(NCHW,NHWC)");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForThreeLayouts)
|
||||
{
|
||||
EXPECT_EQ((tuple_name<ck::Tuple<ck::tensor_layout::convolution::NCHW,
|
||||
ck::tensor_layout::convolution::NHWC,
|
||||
ck::tensor_layout::convolution::NKHW>>()),
|
||||
"Tuple(NCHW,NHWC,NKHW)");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForSingleDataType)
|
||||
{
|
||||
EXPECT_EQ(tuple_name<ck::Tuple<ck::half_t>>(), "Tuple(fp16)");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForTwoDataTypes)
|
||||
{
|
||||
EXPECT_EQ((tuple_name<ck::Tuple<ck::half_t, float>>()), "Tuple(fp16,fp32)");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForThreeDataTypes)
|
||||
{
|
||||
EXPECT_EQ((tuple_name<ck::Tuple<ck::half_t, float, double>>()), "Tuple(fp16,fp32,fp64)");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForEmptySequence)
|
||||
{
|
||||
EXPECT_EQ(sequence_name<ck::Sequence<>>(), "Seq()");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForSingleValueSequence)
|
||||
{
|
||||
EXPECT_EQ(sequence_name<ck::Sequence<42>>(), "Seq(42)");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForTwoValueSequence)
|
||||
{
|
||||
EXPECT_EQ((sequence_name<ck::Sequence<1, 2>>()), "Seq(1,2)");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForMultipleValueSequence)
|
||||
{
|
||||
EXPECT_EQ((sequence_name<ck::Sequence<256, 128, 64, 32, 16>>()), "Seq(256,128,64,32,16)");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace ck_tile::reflect::detail
|
||||
@@ -12,6 +12,8 @@ namespace element_wise {
|
||||
|
||||
struct Add
|
||||
{
|
||||
static constexpr const char* name = "Add";
|
||||
|
||||
template <typename Y, typename X0, typename X1>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
|
||||
|
||||
@@ -279,6 +281,8 @@ struct Subtract
|
||||
|
||||
struct Bilinear
|
||||
{
|
||||
static constexpr const char* name = "Bilinear";
|
||||
|
||||
Bilinear(float alpha = 1.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
|
||||
|
||||
template <typename Y, typename X0, typename X1>
|
||||
@@ -353,6 +357,8 @@ struct Bilinear
|
||||
|
||||
struct AddClamp
|
||||
{
|
||||
static constexpr const char* name = "AddClamp";
|
||||
|
||||
AddClamp(float floor = 0.f, float ceil = NumericLimits<float>::Max())
|
||||
: floor_(floor), ceil_(ceil){};
|
||||
|
||||
@@ -442,6 +448,8 @@ struct AddClamp
|
||||
|
||||
struct AddRelu
|
||||
{
|
||||
static constexpr const char* name = "AddRelu";
|
||||
|
||||
template <typename Y, typename X0, typename X1>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
|
||||
|
||||
|
||||
@@ -565,6 +565,8 @@ struct NormalizeInInfer
|
||||
// used by Conv+Bias+BatchNorm+Clamp inference
|
||||
struct BiasNormalizeInInferClamp
|
||||
{
|
||||
static constexpr const char* name = "BiasNormalizeInInferClamp";
|
||||
|
||||
BiasNormalizeInInferClamp(float floor = 0.f,
|
||||
float ceil = NumericLimits<float>::Max(),
|
||||
float epsilon = 1e-4)
|
||||
|
||||
@@ -332,6 +332,8 @@ struct PassThroughPack2
|
||||
|
||||
struct PassThrough
|
||||
{
|
||||
static constexpr const char* name = "PassThrough";
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const;
|
||||
|
||||
@@ -552,8 +554,6 @@ struct PassThrough
|
||||
{
|
||||
y = type_convert<bf8_t>(x);
|
||||
}
|
||||
|
||||
static constexpr const char* name = "PassThrough";
|
||||
};
|
||||
|
||||
struct UnaryConvert
|
||||
@@ -620,6 +620,8 @@ struct ConvertF8RNE
|
||||
|
||||
struct Scale
|
||||
{
|
||||
static constexpr const char* name = "Scale";
|
||||
|
||||
__host__ __device__ Scale(float scale = 1.f) : scale_(scale) {}
|
||||
|
||||
template <typename Y, typename X>
|
||||
@@ -783,6 +785,8 @@ struct UnarySqrt
|
||||
|
||||
struct Clamp
|
||||
{
|
||||
static constexpr const char* name = "Clamp";
|
||||
|
||||
Clamp(float floor = 0.f, float ceil = NumericLimits<float>::Max())
|
||||
: floor_(floor), ceil_(ceil){};
|
||||
|
||||
@@ -856,6 +860,8 @@ struct Clamp
|
||||
|
||||
struct Relu
|
||||
{
|
||||
static constexpr const char* name = "Relu";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user