mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-30 03:37:38 +00:00
Refactor instance_traits_util and add unit tests tests
This commit is contained in:
committed by
Robin Voetter
parent
ea7f5faa3e
commit
b24e1bf32b
@@ -22,19 +22,12 @@
|
||||
|
||||
namespace ck_tile::reflect::detail {
|
||||
|
||||
// Metaprogramming helper to convert ck::Sequence to constexpr std::array
|
||||
template <typename Seq>
|
||||
struct SequenceToArray;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
struct SequenceToArray<ck::Sequence<Is...>>
|
||||
{
|
||||
static constexpr std::array<int, sizeof...(Is)> value = {static_cast<int>(Is)...};
|
||||
};
|
||||
|
||||
// Convert data types to string names
|
||||
// Implementation detail for type name mapping
|
||||
// This is the single source of truth for supported data types that
|
||||
// returns an empty string to indicate an unsupported type.
|
||||
namespace detail {
|
||||
template <typename T>
|
||||
consteval std::string_view type_name()
|
||||
consteval std::string_view type_name_impl()
|
||||
{
|
||||
if constexpr(std::is_same_v<T, ck::half_t>)
|
||||
return "fp16";
|
||||
@@ -55,20 +48,36 @@ consteval std::string_view type_name()
|
||||
else if constexpr(std::is_same_v<T, ck::bf8_t>)
|
||||
return "bf8";
|
||||
else
|
||||
static_assert(false, "unknown_type");
|
||||
return std::string_view{}; // Return empty for supported types
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
// Convert data types to string names
|
||||
// Fails at compile time for unsupported types
|
||||
template <typename T>
|
||||
consteval std::string_view type_name()
|
||||
{
|
||||
constexpr auto name = detail::type_name_impl<T>();
|
||||
static_assert(!name.empty(), "Unsupported data type");
|
||||
return name;
|
||||
}
|
||||
|
||||
// Convert layout types to string names
|
||||
// Concept that checks if a type is a valid data type
|
||||
// Uses the impl directly to avoid triggering static_assert during concept evaluation
|
||||
template <typename T>
|
||||
concept IsDataType = !detail::type_name_impl<T>().empty();
|
||||
|
||||
// Concept that checks valid layout types
|
||||
template <typename T>
|
||||
concept IsLayoutType = std::is_base_of_v<ck::tensor_layout::BaseTensorLayout, T> && requires {
|
||||
{ T::name } -> std::convertible_to<std::string_view>;
|
||||
};
|
||||
|
||||
// Convert layout types to string names
|
||||
template <IsLayoutType T>
|
||||
constexpr std::string_view layout_name()
|
||||
{
|
||||
if constexpr(std::is_base_of_v<ck_tile::tensor_layout::BaseTensorLayout, T> && requires {
|
||||
{ T::name } -> std::convertible_to<std::string_view>;
|
||||
})
|
||||
return T::name;
|
||||
else
|
||||
static_assert(false,
|
||||
"Layout type must derive from BaseTensorLayout and have name attribute");
|
||||
return T::name;
|
||||
}
|
||||
|
||||
// Convert element-wise operation types to string names
|
||||
@@ -87,64 +96,64 @@ constexpr std::string_view elementwise_op_name()
|
||||
constexpr std::string_view
|
||||
conv_fwd_spec_name(ck::tensor_operation::device::ConvolutionForwardSpecialization spec)
|
||||
{
|
||||
using ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case ConvolutionForwardSpecialization::Default: return "Default";
|
||||
case ConvolutionForwardSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
|
||||
case ConvolutionForwardSpecialization::Filter1x1Pad0: return "Filter1x1Pad0";
|
||||
case ConvolutionForwardSpecialization::Filter3x3: return "Filter3x3";
|
||||
case ConvolutionForwardSpecialization::OddC: return "OddC";
|
||||
case Default: return "Default";
|
||||
case Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
|
||||
case Filter1x1Pad0: return "Filter1x1Pad0";
|
||||
case Filter3x3: return "Filter3x3";
|
||||
case OddC: return "OddC";
|
||||
}
|
||||
}
|
||||
|
||||
// Convert GemmSpecialization enum to string
|
||||
constexpr std::string_view gemm_spec_name(ck::tensor_operation::device::GemmSpecialization spec)
|
||||
{
|
||||
using ck::tensor_operation::device::GemmSpecialization;
|
||||
using enum ck::tensor_operation::device::GemmSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case GemmSpecialization::Default: return "Default";
|
||||
case GemmSpecialization::MPadding: return "MPadding";
|
||||
case GemmSpecialization::NPadding: return "NPadding";
|
||||
case GemmSpecialization::KPadding: return "KPadding";
|
||||
case GemmSpecialization::MNPadding: return "MNPadding";
|
||||
case GemmSpecialization::MKPadding: return "MKPadding";
|
||||
case GemmSpecialization::NKPadding: return "NKPadding";
|
||||
case GemmSpecialization::MNKPadding: return "MNKPadding";
|
||||
case GemmSpecialization::OPadding: return "OPadding";
|
||||
case GemmSpecialization::MOPadding: return "MOPadding";
|
||||
case GemmSpecialization::NOPadding: return "NOPadding";
|
||||
case GemmSpecialization::KOPadding: return "KOPadding";
|
||||
case GemmSpecialization::MNOPadding: return "MNOPadding";
|
||||
case GemmSpecialization::MKOPadding: return "MKOPadding";
|
||||
case GemmSpecialization::NKOPadding: return "NKOPadding";
|
||||
case GemmSpecialization::MNKOPadding: return "MNKOPadding";
|
||||
case Default: return "Default";
|
||||
case MPadding: return "MPadding";
|
||||
case NPadding: return "NPadding";
|
||||
case KPadding: return "KPadding";
|
||||
case MNPadding: return "MNPadding";
|
||||
case MKPadding: return "MKPadding";
|
||||
case NKPadding: return "NKPadding";
|
||||
case MNKPadding: return "MNKPadding";
|
||||
case OPadding: return "OPadding";
|
||||
case MOPadding: return "MOPadding";
|
||||
case NOPadding: return "NOPadding";
|
||||
case KOPadding: return "KOPadding";
|
||||
case MNOPadding: return "MNOPadding";
|
||||
case MKOPadding: return "MKOPadding";
|
||||
case NKOPadding: return "NKOPadding";
|
||||
case MNKOPadding: return "MNKOPadding";
|
||||
}
|
||||
}
|
||||
|
||||
// Convert BlockGemmPipelineScheduler enum to string
|
||||
constexpr std::string_view pipeline_scheduler_name(ck::BlockGemmPipelineScheduler sched)
|
||||
{
|
||||
using ck::BlockGemmPipelineScheduler;
|
||||
using enum ck::BlockGemmPipelineScheduler;
|
||||
switch(sched)
|
||||
{
|
||||
case BlockGemmPipelineScheduler::Intrawave: return "Intrawave";
|
||||
case BlockGemmPipelineScheduler::Interwave: return "Interwave";
|
||||
case Intrawave: return "Intrawave";
|
||||
case Interwave: return "Interwave";
|
||||
}
|
||||
}
|
||||
|
||||
// Convert BlockGemmPipelineVersion enum to string
|
||||
constexpr std::string_view pipeline_version_name(ck::BlockGemmPipelineVersion ver)
|
||||
{
|
||||
using ck::BlockGemmPipelineVersion;
|
||||
using enum ck::BlockGemmPipelineVersion;
|
||||
switch(ver)
|
||||
{
|
||||
case BlockGemmPipelineVersion::v1: return "v1";
|
||||
case BlockGemmPipelineVersion::v2: return "v2";
|
||||
case BlockGemmPipelineVersion::v3: return "v3";
|
||||
case BlockGemmPipelineVersion::v4: return "v4";
|
||||
case BlockGemmPipelineVersion::v5: return "v5";
|
||||
case v1: return "v1";
|
||||
case v2: return "v2";
|
||||
case v3: return "v3";
|
||||
case v4: return "v4";
|
||||
case v5: return "v5";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -164,12 +173,138 @@ inline std::string array_to_string(const std::array<T, N>& arr)
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
// Handle ck::Tuple (empty tuple for DsLayout/DsDataType)
|
||||
template <typename T>
|
||||
constexpr std::string_view tuple_name()
|
||||
// Metaprogramming helper to convert ck::Sequence to constexpr std::array
|
||||
template <typename Seq>
|
||||
struct SequenceToArray;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
struct SequenceToArray<ck::Sequence<Is...>>
|
||||
{
|
||||
// For now, just check if it's an empty tuple
|
||||
return "EmptyTuple";
|
||||
static constexpr std::array<int, sizeof...(Is)> value = {static_cast<int>(Is)...};
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
// Generic helper to build list-like strings (Tuple, Seq, etc.)
|
||||
//
|
||||
// Example output: "Seq(1,2,3)"
|
||||
//
|
||||
// prefix: The list-like container name (e.g. "Tuple" or "Seq")
|
||||
// converter_fn: A callable that converts each element to a string representation
|
||||
// For types: converter_fn should be a template lambda like []<typename U>() { return
|
||||
// type_name<U>(); } For values: converter_fn should be a regular lambda like [](auto value) {
|
||||
// return std::to_string(value); }
|
||||
template <typename ConverterFn, typename... Elements>
|
||||
constexpr std::string build_list_string(std::string_view prefix, const ConverterFn& converter_fn)
|
||||
{
|
||||
if constexpr(sizeof...(Elements) == 0)
|
||||
{
|
||||
return std::string(prefix) + "()";
|
||||
}
|
||||
else
|
||||
{
|
||||
std::string result = std::string(prefix) + "(";
|
||||
std::size_t index = 0;
|
||||
((result +=
|
||||
(index++ > 0 ? "," : "") + std::string(converter_fn.template operator()<Elements>())),
|
||||
...);
|
||||
result += ")";
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
// Overload for value-based lists (sequences)
|
||||
template <typename ConverterFn, auto... Values>
|
||||
constexpr std::string build_list_string_values(std::string_view prefix,
|
||||
const ConverterFn& converter_fn)
|
||||
{
|
||||
if constexpr(sizeof...(Values) == 0)
|
||||
{
|
||||
return std::string(prefix) + "()";
|
||||
}
|
||||
else
|
||||
{
|
||||
std::string result = std::string(prefix) + "(";
|
||||
std::size_t index = 0;
|
||||
((result += (index++ > 0 ? "," : "") + converter_fn(Values)), ...);
|
||||
result += ")";
|
||||
return result;
|
||||
}
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
// Convert ck::Sequence to string representation
|
||||
// Converts a ck::Sequence type to a string in the format "Seq(v1,v2,...,vn)"
|
||||
// where each value is converted using std::to_string.
|
||||
//
|
||||
// Template parameter:
|
||||
// T: Must be a ck::Sequence<...> type
|
||||
//
|
||||
// Constraints:
|
||||
// - Sequence elements must support std::to_string (typically ck::index_t)
|
||||
//
|
||||
// Examples:
|
||||
// sequence_name<ck::Sequence<>>() returns "Seq()"
|
||||
// sequence_name<ck::Sequence<42>>() returns "Seq(42)"
|
||||
// sequence_name<ck::Sequence<1,2,3>>() returns "Seq(1,2,3)"
|
||||
// sequence_name<ck::Sequence<256,128,64>>() returns "Seq(256,128,64)"
|
||||
template <typename T>
|
||||
requires requires { []<ck::index_t... Is>(ck::Sequence<Is...>*) {}(static_cast<T*>(nullptr)); }
|
||||
constexpr std::string sequence_name()
|
||||
{
|
||||
return []<ck::index_t... Is>(ck::Sequence<Is...>*) constexpr {
|
||||
auto to_string_fn = [](auto value) { return std::to_string(value); };
|
||||
return detail::build_list_string_values<decltype(to_string_fn), Is...>("Seq", to_string_fn);
|
||||
}(static_cast<T*>(nullptr));
|
||||
}
|
||||
|
||||
// Convert ck::Tuple to string representation
|
||||
// Converts a ck::Tuple type to a string in the format "Tuple(e1,e2,...,en)"
|
||||
// where each element is converted based on its type (layout names or data type names).
|
||||
//
|
||||
// Template parameter:
|
||||
// T: Must be a ck::Tuple<...> type
|
||||
//
|
||||
// Constraints:
|
||||
// - Empty tuples are supported and return "EmptyTuple"
|
||||
// - All tuple elements must be homogeneous: either all layouts (IsLayoutType) or all data types
|
||||
// (IsDataType)
|
||||
// - Mixed layouts and data types in the same tuple will cause a compile-time error
|
||||
//
|
||||
// Examples:
|
||||
// tuple_name<ck::Tuple<>>() returns "EmptyTuple"
|
||||
// tuple_name<ck::Tuple<ck::tensor_layout::gemm::RowMajor>>() returns "Tuple(RowMajor)"
|
||||
// tuple_name<ck::Tuple<NCHW,NHWC>>() returns "Tuple(NCHW,NHWC)"
|
||||
// tuple_name<ck::Tuple<ck::half_t>>() returns "Tuple(fp16)"
|
||||
// tuple_name<ck::Tuple<ck::half_t,float,double>>() returns "Tuple(fp16,fp32,fp64)"
|
||||
template <typename T>
|
||||
requires requires { []<typename... Ts>(ck::Tuple<Ts...>*) {}(static_cast<T*>(nullptr)); }
|
||||
constexpr std::string tuple_name()
|
||||
{
|
||||
return []<typename... Ts>(ck::Tuple<Ts...>*) constexpr {
|
||||
if constexpr(sizeof...(Ts) == 0)
|
||||
{
|
||||
return std::string("EmptyTuple");
|
||||
}
|
||||
else if constexpr((IsLayoutType<Ts> && ...))
|
||||
{
|
||||
// Lambda wrapper for layout_name
|
||||
auto layout_name_fn = []<typename U>() { return layout_name<U>(); };
|
||||
return detail::build_list_string<decltype(layout_name_fn), Ts...>("Tuple",
|
||||
layout_name_fn);
|
||||
}
|
||||
else if constexpr((IsDataType<Ts> && ...))
|
||||
{
|
||||
// Lambda wrapper for type_name
|
||||
auto type_name_fn = []<typename U>() { return type_name<U>(); };
|
||||
return detail::build_list_string<decltype(type_name_fn), Ts...>("Tuple", type_name_fn);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert((IsLayoutType<Ts> && ...) || (IsDataType<Ts> && ...),
|
||||
"Tuple elements must be all layouts or all data types, not mixed");
|
||||
return std::string{}; // unreachable
|
||||
}
|
||||
}(static_cast<T*>(nullptr));
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect::detail
|
||||
|
||||
@@ -18,6 +18,7 @@ endfunction()
|
||||
add_ck_builder_test(test_conv_builder
|
||||
test_conv_builder.cpp
|
||||
test_instance_traits.cpp
|
||||
test_instance_traits_util.cpp
|
||||
testing_utils.cpp)
|
||||
|
||||
add_ck_builder_test(test_get_instance_string
|
||||
|
||||
260
experimental/builder/test/test_instance_traits_util.cpp
Normal file
260
experimental/builder/test/test_instance_traits_util.cpp
Normal file
@@ -0,0 +1,260 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
#include <ck_tile/builder/reflect/instance_traits_util.hpp>
|
||||
#include <ck/utility/data_type.hpp>
|
||||
#include <ck/utility/sequence.hpp>
|
||||
#include <ck/utility/blkgemmpipe_scheduler.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
|
||||
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
|
||||
|
||||
namespace ck_tile::reflect::detail {
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAre;
|
||||
using ::testing::IsEmpty;
|
||||
|
||||
TEST(InstanceTraitsUtil, SequenceToArrayReturnsEmptyArrayForEmptySequence)
|
||||
{
|
||||
EXPECT_THAT(SequenceToArray<ck::Sequence<>>::value, IsEmpty());
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, SequenceToArrayReturnsArrayWithSingleElement)
|
||||
{
|
||||
EXPECT_THAT(SequenceToArray<ck::Sequence<42>>::value, ElementsAre(42));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, SequenceToArrayReturnsArrayWithMultipleElements)
|
||||
{
|
||||
EXPECT_THAT((SequenceToArray<ck::Sequence<1, 2, 3, 4, 5>>::value), ElementsAre(1, 2, 3, 4, 5));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, TypeNameReturnsCorrectStrings)
|
||||
{
|
||||
EXPECT_THAT((std::vector<std::string_view>{type_name<ck::half_t>(),
|
||||
type_name<float>(),
|
||||
type_name<double>(),
|
||||
type_name<int8_t>(),
|
||||
type_name<int32_t>(),
|
||||
type_name<ck::bhalf_t>(),
|
||||
type_name<ck::f8_t>(),
|
||||
type_name<ck::bf8_t>()}),
|
||||
ElementsAre("fp16", "fp32", "fp64", "s8", "s32", "bf16", "fp8", "bf8"));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, LayoutNameReturnsCorrectStringsForGemmLayouts)
|
||||
{
|
||||
namespace gemm = ck::tensor_layout::gemm;
|
||||
EXPECT_THAT((std::vector<std::string_view>{layout_name<gemm::RowMajor>(),
|
||||
layout_name<gemm::ColumnMajor>(),
|
||||
layout_name<gemm::MFMA>()}),
|
||||
ElementsAre("RowMajor", "ColumnMajor", "MFMA"));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, LayoutNameReturnsCorrectStringsForConvLayouts)
|
||||
{
|
||||
namespace conv = ck::tensor_layout::convolution;
|
||||
EXPECT_THAT((std::vector<std::string_view>{// Input tensor layouts (packed)
|
||||
layout_name<conv::NCHW>(),
|
||||
layout_name<conv::NHWC>(),
|
||||
layout_name<conv::NCDHW>(),
|
||||
layout_name<conv::NDHWC>(),
|
||||
// Grouped input layouts
|
||||
layout_name<conv::GNCHW>(),
|
||||
layout_name<conv::GNHWC>(),
|
||||
// Weight tensor layouts
|
||||
layout_name<conv::KCYX>(),
|
||||
layout_name<conv::KYXC>(),
|
||||
layout_name<conv::GKCYX>(),
|
||||
layout_name<conv::GKYXC>(),
|
||||
// Output tensor layouts
|
||||
layout_name<conv::NKHW>(),
|
||||
layout_name<conv::NHWK>(),
|
||||
layout_name<conv::GNKHW>(),
|
||||
layout_name<conv::GNHWK>(),
|
||||
// Strided layouts
|
||||
layout_name<conv::G_NHW_C>(),
|
||||
layout_name<conv::G_K_YX_C>(),
|
||||
layout_name<conv::G_NHW_K>(),
|
||||
// Bias layouts
|
||||
layout_name<conv::G_C>(),
|
||||
layout_name<conv::G_K>()}),
|
||||
ElementsAre("NCHW",
|
||||
"NHWC",
|
||||
"NCDHW",
|
||||
"NDHWC",
|
||||
"GNCHW",
|
||||
"GNHWC",
|
||||
"KCYX",
|
||||
"KYXC",
|
||||
"GKCYX",
|
||||
"GKYXC",
|
||||
"NKHW",
|
||||
"NHWK",
|
||||
"GNKHW",
|
||||
"GNHWK",
|
||||
"G_NHW_C",
|
||||
"G_K_YX_C",
|
||||
"G_NHW_K",
|
||||
"G_C",
|
||||
"G_K"));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, ElementwiseOpNameReturnsCorrectStrings)
|
||||
{
|
||||
namespace element_wise = ck::tensor_operation::element_wise;
|
||||
EXPECT_THAT((std::vector<std::string_view>{
|
||||
elementwise_op_name<element_wise::PassThrough>(),
|
||||
elementwise_op_name<element_wise::Scale>(),
|
||||
elementwise_op_name<element_wise::Bilinear>(),
|
||||
elementwise_op_name<element_wise::Add>(),
|
||||
elementwise_op_name<element_wise::AddRelu>(),
|
||||
elementwise_op_name<element_wise::Relu>(),
|
||||
elementwise_op_name<element_wise::BiasNormalizeInInferClamp>(),
|
||||
elementwise_op_name<element_wise::Clamp>(),
|
||||
elementwise_op_name<element_wise::AddClamp>()}),
|
||||
ElementsAre("PassThrough",
|
||||
"Scale",
|
||||
"Bilinear",
|
||||
"Add",
|
||||
"AddRelu",
|
||||
"Relu",
|
||||
"BiasNormalizeInInferClamp",
|
||||
"Clamp",
|
||||
"AddClamp"));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, ConvFwdSpecNameReturnsCorrectStrings)
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
EXPECT_THAT(
|
||||
(std::vector<std::string_view>{conv_fwd_spec_name(Default),
|
||||
conv_fwd_spec_name(Filter1x1Stride1Pad0),
|
||||
conv_fwd_spec_name(Filter1x1Pad0),
|
||||
conv_fwd_spec_name(Filter3x3),
|
||||
conv_fwd_spec_name(OddC)}),
|
||||
ElementsAre("Default", "Filter1x1Stride1Pad0", "Filter1x1Pad0", "Filter3x3", "OddC"));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, GemmSpecNameReturnsCorrectStrings)
|
||||
{
|
||||
using enum ck::tensor_operation::device::GemmSpecialization;
|
||||
EXPECT_THAT((std::vector<std::string_view>{gemm_spec_name(Default),
|
||||
gemm_spec_name(MPadding),
|
||||
gemm_spec_name(NPadding),
|
||||
gemm_spec_name(KPadding),
|
||||
gemm_spec_name(MNPadding),
|
||||
gemm_spec_name(MKPadding),
|
||||
gemm_spec_name(NKPadding),
|
||||
gemm_spec_name(MNKPadding),
|
||||
gemm_spec_name(OPadding),
|
||||
gemm_spec_name(MOPadding),
|
||||
gemm_spec_name(NOPadding),
|
||||
gemm_spec_name(KOPadding),
|
||||
gemm_spec_name(MNOPadding),
|
||||
gemm_spec_name(MKOPadding),
|
||||
gemm_spec_name(NKOPadding),
|
||||
gemm_spec_name(MNKOPadding)}),
|
||||
ElementsAre("Default",
|
||||
"MPadding",
|
||||
"NPadding",
|
||||
"KPadding",
|
||||
"MNPadding",
|
||||
"MKPadding",
|
||||
"NKPadding",
|
||||
"MNKPadding",
|
||||
"OPadding",
|
||||
"MOPadding",
|
||||
"NOPadding",
|
||||
"KOPadding",
|
||||
"MNOPadding",
|
||||
"MKOPadding",
|
||||
"NKOPadding",
|
||||
"MNKOPadding"));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, PipelineSchedulerNameReturnsCorrectStrings)
|
||||
{
|
||||
using enum ck::BlockGemmPipelineScheduler;
|
||||
EXPECT_THAT((std::vector<std::string_view>{pipeline_scheduler_name(Intrawave),
|
||||
pipeline_scheduler_name(Interwave)}),
|
||||
ElementsAre("Intrawave", "Interwave"));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, PipelineVersionNameReturnsCorrectStrings)
|
||||
{
|
||||
using enum ck::BlockGemmPipelineVersion;
|
||||
EXPECT_THAT((std::vector<std::string_view>{pipeline_version_name(v1),
|
||||
pipeline_version_name(v2),
|
||||
pipeline_version_name(v3),
|
||||
pipeline_version_name(v4),
|
||||
pipeline_version_name(v5)}),
|
||||
ElementsAre("v1", "v2", "v3", "v4", "v5"));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, TupleNameReturnsEmptyTupleForEmptyTuple)
|
||||
{
|
||||
EXPECT_EQ(tuple_name<ck::Tuple<>>(), "EmptyTuple");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForSingleLayout)
|
||||
{
|
||||
EXPECT_EQ(tuple_name<ck::Tuple<ck::tensor_layout::convolution::NCHW>>(), "Tuple(NCHW)");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForTwoLayouts)
|
||||
{
|
||||
EXPECT_EQ((tuple_name<ck::Tuple<ck::tensor_layout::convolution::NCHW,
|
||||
ck::tensor_layout::convolution::NHWC>>()),
|
||||
"Tuple(NCHW,NHWC)");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForThreeLayouts)
|
||||
{
|
||||
EXPECT_EQ((tuple_name<ck::Tuple<ck::tensor_layout::convolution::NCHW,
|
||||
ck::tensor_layout::convolution::NHWC,
|
||||
ck::tensor_layout::convolution::NKHW>>()),
|
||||
"Tuple(NCHW,NHWC,NKHW)");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForSingleDataType)
|
||||
{
|
||||
EXPECT_EQ(tuple_name<ck::Tuple<ck::half_t>>(), "Tuple(fp16)");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForTwoDataTypes)
|
||||
{
|
||||
EXPECT_EQ((tuple_name<ck::Tuple<ck::half_t, float>>()), "Tuple(fp16,fp32)");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForThreeDataTypes)
|
||||
{
|
||||
EXPECT_EQ((tuple_name<ck::Tuple<ck::half_t, float, double>>()), "Tuple(fp16,fp32,fp64)");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForEmptySequence)
|
||||
{
|
||||
EXPECT_EQ(sequence_name<ck::Sequence<>>(), "Seq()");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForSingleValueSequence)
|
||||
{
|
||||
EXPECT_EQ(sequence_name<ck::Sequence<42>>(), "Seq(42)");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForTwoValueSequence)
|
||||
{
|
||||
EXPECT_EQ((sequence_name<ck::Sequence<1, 2>>()), "Seq(1,2)");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForMultipleValueSequence)
|
||||
{
|
||||
EXPECT_EQ((sequence_name<ck::Sequence<256, 128, 64, 32, 16>>()), "Seq(256,128,64,32,16)");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace ck_tile::reflect::detail
|
||||
@@ -12,6 +12,8 @@ namespace element_wise {
|
||||
|
||||
struct Add
|
||||
{
|
||||
static constexpr const char* name = "Add";
|
||||
|
||||
template <typename Y, typename X0, typename X1>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
|
||||
|
||||
@@ -279,6 +281,8 @@ struct Subtract
|
||||
|
||||
struct Bilinear
|
||||
{
|
||||
static constexpr const char* name = "Bilinear";
|
||||
|
||||
Bilinear(float alpha = 1.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
|
||||
|
||||
template <typename Y, typename X0, typename X1>
|
||||
@@ -353,6 +357,8 @@ struct Bilinear
|
||||
|
||||
struct AddClamp
|
||||
{
|
||||
static constexpr const char* name = "AddClamp";
|
||||
|
||||
AddClamp(float floor = 0.f, float ceil = NumericLimits<float>::Max())
|
||||
: floor_(floor), ceil_(ceil){};
|
||||
|
||||
@@ -442,6 +448,8 @@ struct AddClamp
|
||||
|
||||
struct AddRelu
|
||||
{
|
||||
static constexpr const char* name = "AddRelu";
|
||||
|
||||
template <typename Y, typename X0, typename X1>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
|
||||
|
||||
|
||||
@@ -565,6 +565,8 @@ struct NormalizeInInfer
|
||||
// used by Conv+Bias+BatchNorm+Clamp inference
|
||||
struct BiasNormalizeInInferClamp
|
||||
{
|
||||
static constexpr const char* name = "BiasNormalizeInInferClamp";
|
||||
|
||||
BiasNormalizeInInferClamp(float floor = 0.f,
|
||||
float ceil = NumericLimits<float>::Max(),
|
||||
float epsilon = 1e-4)
|
||||
|
||||
@@ -332,6 +332,8 @@ struct PassThroughPack2
|
||||
|
||||
struct PassThrough
|
||||
{
|
||||
static constexpr const char* name = "PassThrough";
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const;
|
||||
|
||||
@@ -618,6 +620,8 @@ struct ConvertF8RNE
|
||||
|
||||
struct Scale
|
||||
{
|
||||
static constexpr const char* name = "Scale";
|
||||
|
||||
__host__ __device__ Scale(float scale = 1.f) : scale_(scale) {}
|
||||
|
||||
template <typename Y, typename X>
|
||||
@@ -781,6 +785,8 @@ struct UnarySqrt
|
||||
|
||||
struct Clamp
|
||||
{
|
||||
static constexpr const char* name = "Clamp";
|
||||
|
||||
Clamp(float floor = 0.f, float ceil = NumericLimits<float>::Max())
|
||||
: floor_(floor), ceil_(ceil){};
|
||||
|
||||
@@ -854,6 +860,8 @@ struct Clamp
|
||||
|
||||
struct Relu
|
||||
{
|
||||
static constexpr const char* name = "Relu";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user