mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-16 10:59:55 +00:00
Merge commit '1c17bae816edc44c32ee9d1a19d79d768fd1be13' into develop
This commit is contained in:
@@ -23,19 +23,12 @@
|
||||
|
||||
namespace ck_tile::reflect::detail {
|
||||
|
||||
// Metaprogramming helper to convert ck::Sequence to constexpr std::array
|
||||
template <typename Seq>
|
||||
struct SequenceToArray;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
struct SequenceToArray<ck::Sequence<Is...>>
|
||||
{
|
||||
static constexpr std::array<int, sizeof...(Is)> value = {static_cast<int>(Is)...};
|
||||
};
|
||||
|
||||
// Convert data types to string names
|
||||
// Implementation detail for type name mapping
|
||||
// This is the single source of truth for supported data types that
|
||||
// returns an empty string to indicate an unsupported type.
|
||||
namespace impl {
|
||||
template <typename T>
|
||||
consteval std::string_view type_name()
|
||||
consteval std::string_view type_name_impl()
|
||||
{
|
||||
if constexpr(std::is_same_v<T, ck::half_t>)
|
||||
return "fp16";
|
||||
@@ -56,22 +49,38 @@ consteval std::string_view type_name()
|
||||
else if constexpr(std::is_same_v<T, ck::bf8_t>)
|
||||
return "bf8";
|
||||
else
|
||||
static_assert(false, "unknown_type");
|
||||
return std::string_view{}; // Return empty for supported types
|
||||
}
|
||||
} // namespace impl
|
||||
|
||||
// Convert data types to string names
|
||||
// Fails at compile time for unsupported types
|
||||
template <typename T>
|
||||
consteval std::string_view type_name()
|
||||
{
|
||||
constexpr auto name = impl::type_name_impl<T>();
|
||||
static_assert(!name.empty(), "Unsupported data type");
|
||||
return name;
|
||||
}
|
||||
|
||||
// Convert layout types to string names
|
||||
// Concept that checks if a type is a valid data type
|
||||
// Uses the impl directly to avoid triggering static_assert during concept evaluation
|
||||
template <typename T>
|
||||
concept IsDataType = !impl::type_name_impl<T>().empty();
|
||||
|
||||
// Concept that checks valid layout types
|
||||
template <typename T>
|
||||
concept IsLayoutType = (std::is_base_of_v<ck_tile::tensor_layout::BaseTensorLayout, T> ||
|
||||
std::is_base_of_v<ck::tensor_layout::BaseTensorLayout, T>) &&
|
||||
requires {
|
||||
{ T::name } -> std::convertible_to<std::string_view>;
|
||||
};
|
||||
|
||||
// Convert layout types to string names
|
||||
template <IsLayoutType T>
|
||||
constexpr std::string_view layout_name()
|
||||
{
|
||||
if constexpr((std::is_base_of_v<ck_tile::tensor_layout::BaseTensorLayout, T> ||
|
||||
std::is_base_of_v<ck::tensor_layout::BaseTensorLayout, T>) &&
|
||||
requires {
|
||||
{ T::name } -> std::convertible_to<std::string_view>;
|
||||
})
|
||||
return T::name;
|
||||
else
|
||||
static_assert(false,
|
||||
"Layout type must derive from BaseTensorLayout and have name attribute");
|
||||
return T::name;
|
||||
}
|
||||
|
||||
// Convert element-wise operation types to string names
|
||||
@@ -90,64 +99,64 @@ constexpr std::string_view elementwise_op_name()
|
||||
constexpr std::string_view
|
||||
conv_fwd_spec_name(ck::tensor_operation::device::ConvolutionForwardSpecialization spec)
|
||||
{
|
||||
using ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case ConvolutionForwardSpecialization::Default: return "Default";
|
||||
case ConvolutionForwardSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
|
||||
case ConvolutionForwardSpecialization::Filter1x1Pad0: return "Filter1x1Pad0";
|
||||
case ConvolutionForwardSpecialization::Filter3x3: return "Filter3x3";
|
||||
case ConvolutionForwardSpecialization::OddC: return "OddC";
|
||||
case Default: return "Default";
|
||||
case Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
|
||||
case Filter1x1Pad0: return "Filter1x1Pad0";
|
||||
case Filter3x3: return "Filter3x3";
|
||||
case OddC: return "OddC";
|
||||
}
|
||||
}
|
||||
|
||||
// Convert GemmSpecialization enum to string
|
||||
constexpr std::string_view gemm_spec_name(ck::tensor_operation::device::GemmSpecialization spec)
|
||||
{
|
||||
using ck::tensor_operation::device::GemmSpecialization;
|
||||
using enum ck::tensor_operation::device::GemmSpecialization;
|
||||
switch(spec)
|
||||
{
|
||||
case GemmSpecialization::Default: return "Default";
|
||||
case GemmSpecialization::MPadding: return "MPadding";
|
||||
case GemmSpecialization::NPadding: return "NPadding";
|
||||
case GemmSpecialization::KPadding: return "KPadding";
|
||||
case GemmSpecialization::MNPadding: return "MNPadding";
|
||||
case GemmSpecialization::MKPadding: return "MKPadding";
|
||||
case GemmSpecialization::NKPadding: return "NKPadding";
|
||||
case GemmSpecialization::MNKPadding: return "MNKPadding";
|
||||
case GemmSpecialization::OPadding: return "OPadding";
|
||||
case GemmSpecialization::MOPadding: return "MOPadding";
|
||||
case GemmSpecialization::NOPadding: return "NOPadding";
|
||||
case GemmSpecialization::KOPadding: return "KOPadding";
|
||||
case GemmSpecialization::MNOPadding: return "MNOPadding";
|
||||
case GemmSpecialization::MKOPadding: return "MKOPadding";
|
||||
case GemmSpecialization::NKOPadding: return "NKOPadding";
|
||||
case GemmSpecialization::MNKOPadding: return "MNKOPadding";
|
||||
case Default: return "Default";
|
||||
case MPadding: return "MPadding";
|
||||
case NPadding: return "NPadding";
|
||||
case KPadding: return "KPadding";
|
||||
case MNPadding: return "MNPadding";
|
||||
case MKPadding: return "MKPadding";
|
||||
case NKPadding: return "NKPadding";
|
||||
case MNKPadding: return "MNKPadding";
|
||||
case OPadding: return "OPadding";
|
||||
case MOPadding: return "MOPadding";
|
||||
case NOPadding: return "NOPadding";
|
||||
case KOPadding: return "KOPadding";
|
||||
case MNOPadding: return "MNOPadding";
|
||||
case MKOPadding: return "MKOPadding";
|
||||
case NKOPadding: return "NKOPadding";
|
||||
case MNKOPadding: return "MNKOPadding";
|
||||
}
|
||||
}
|
||||
|
||||
// Convert BlockGemmPipelineScheduler enum to string
|
||||
constexpr std::string_view pipeline_scheduler_name(ck::BlockGemmPipelineScheduler sched)
|
||||
{
|
||||
using ck::BlockGemmPipelineScheduler;
|
||||
using enum ck::BlockGemmPipelineScheduler;
|
||||
switch(sched)
|
||||
{
|
||||
case BlockGemmPipelineScheduler::Intrawave: return "Intrawave";
|
||||
case BlockGemmPipelineScheduler::Interwave: return "Interwave";
|
||||
case Intrawave: return "Intrawave";
|
||||
case Interwave: return "Interwave";
|
||||
}
|
||||
}
|
||||
|
||||
// Convert BlockGemmPipelineVersion enum to string
|
||||
constexpr std::string_view pipeline_version_name(ck::BlockGemmPipelineVersion ver)
|
||||
{
|
||||
using ck::BlockGemmPipelineVersion;
|
||||
using enum ck::BlockGemmPipelineVersion;
|
||||
switch(ver)
|
||||
{
|
||||
case BlockGemmPipelineVersion::v1: return "v1";
|
||||
case BlockGemmPipelineVersion::v2: return "v2";
|
||||
case BlockGemmPipelineVersion::v3: return "v3";
|
||||
case BlockGemmPipelineVersion::v4: return "v4";
|
||||
case BlockGemmPipelineVersion::v5: return "v5";
|
||||
case v1: return "v1";
|
||||
case v2: return "v2";
|
||||
case v3: return "v3";
|
||||
case v4: return "v4";
|
||||
case v5: return "v5";
|
||||
}
|
||||
}
|
||||
|
||||
@@ -167,12 +176,138 @@ inline std::string array_to_string(const std::array<T, N>& arr)
|
||||
return oss.str();
|
||||
}
|
||||
|
||||
// Handle ck::Tuple (empty tuple for DsLayout/DsDataType)
|
||||
template <typename T>
|
||||
constexpr std::string_view tuple_name()
|
||||
// Metaprogramming helper to convert ck::Sequence to constexpr std::array
|
||||
template <typename Seq>
|
||||
struct SequenceToArray;
|
||||
|
||||
template <ck::index_t... Is>
|
||||
struct SequenceToArray<ck::Sequence<Is...>>
|
||||
{
|
||||
// For now, just check if it's an empty tuple
|
||||
return "EmptyTuple";
|
||||
static constexpr std::array<int, sizeof...(Is)> value = {static_cast<int>(Is)...};
|
||||
};
|
||||
|
||||
namespace detail {
|
||||
// Generic helper to build list-like strings (Tuple, Seq, etc.)
|
||||
//
|
||||
// Example output: "Seq(1,2,3)"
|
||||
//
|
||||
// prefix: The list-like container name (e.g. "Tuple" or "Seq")
|
||||
// converter_fn: A callable that converts each element to a string representation
|
||||
// For types: converter_fn should be a template lambda like []<typename U>() { return
|
||||
// type_name<U>(); } For values: converter_fn should be a regular lambda like [](auto value) {
|
||||
// return std::to_string(value); }
|
||||
template <typename ConverterFn, typename... Elements>
|
||||
constexpr std::string build_list_string(std::string_view prefix, const ConverterFn& converter_fn)
|
||||
{
|
||||
if constexpr(sizeof...(Elements) == 0)
|
||||
{
|
||||
return std::string(prefix) + "()";
|
||||
}
|
||||
else
|
||||
{
|
||||
std::string result = std::string(prefix) + "(";
|
||||
std::size_t index = 0;
|
||||
((result +=
|
||||
(index++ > 0 ? "," : "") + std::string(converter_fn.template operator()<Elements>())),
|
||||
...);
|
||||
result += ")";
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
// Overload for value-based lists (sequences)
|
||||
template <typename ConverterFn, auto... Values>
|
||||
constexpr std::string build_list_string_values(std::string_view prefix,
|
||||
const ConverterFn& converter_fn)
|
||||
{
|
||||
if constexpr(sizeof...(Values) == 0)
|
||||
{
|
||||
return std::string(prefix) + "()";
|
||||
}
|
||||
else
|
||||
{
|
||||
std::string result = std::string(prefix) + "(";
|
||||
std::size_t index = 0;
|
||||
((result += (index++ > 0 ? "," : "") + converter_fn(Values)), ...);
|
||||
result += ")";
|
||||
return result;
|
||||
}
|
||||
}
|
||||
} // namespace detail
|
||||
|
||||
// Convert ck::Sequence to string representation
|
||||
// Converts a ck::Sequence type to a string in the format "Seq(v1,v2,...,vn)"
|
||||
// where each value is converted using std::to_string.
|
||||
//
|
||||
// Template parameter:
|
||||
// T: Must be a ck::Sequence<...> type
|
||||
//
|
||||
// Constraints:
|
||||
// - Sequence elements must support std::to_string (typically ck::index_t)
|
||||
//
|
||||
// Examples:
|
||||
// sequence_name<ck::Sequence<>>() returns "Seq()"
|
||||
// sequence_name<ck::Sequence<42>>() returns "Seq(42)"
|
||||
// sequence_name<ck::Sequence<1,2,3>>() returns "Seq(1,2,3)"
|
||||
// sequence_name<ck::Sequence<256,128,64>>() returns "Seq(256,128,64)"
|
||||
template <typename T>
|
||||
requires requires { []<ck::index_t... Is>(ck::Sequence<Is...>*) {}(static_cast<T*>(nullptr)); }
|
||||
constexpr std::string sequence_name()
|
||||
{
|
||||
return []<ck::index_t... Is>(ck::Sequence<Is...>*) constexpr {
|
||||
auto to_string_fn = [](auto value) { return std::to_string(value); };
|
||||
return detail::build_list_string_values<decltype(to_string_fn), Is...>("Seq", to_string_fn);
|
||||
}(static_cast<T*>(nullptr));
|
||||
}
|
||||
|
||||
// Convert ck::Tuple to string representation
|
||||
// Converts a ck::Tuple type to a string in the format "Tuple(e1,e2,...,en)"
|
||||
// where each element is converted based on its type (layout names or data type names).
|
||||
//
|
||||
// Template parameter:
|
||||
// T: Must be a ck::Tuple<...> type
|
||||
//
|
||||
// Constraints:
|
||||
// - Empty tuples are supported and return "EmptyTuple"
|
||||
// - All tuple elements must be homogeneous: either all layouts (IsLayoutType) or all data types
|
||||
// (IsDataType)
|
||||
// - Mixed layouts and data types in the same tuple will cause a compile-time error
|
||||
//
|
||||
// Examples:
|
||||
// tuple_name<ck::Tuple<>>() returns "EmptyTuple"
|
||||
// tuple_name<ck::Tuple<ck::tensor_layout::gemm::RowMajor>>() returns "Tuple(RowMajor)"
|
||||
// tuple_name<ck::Tuple<NCHW,NHWC>>() returns "Tuple(NCHW,NHWC)"
|
||||
// tuple_name<ck::Tuple<ck::half_t>>() returns "Tuple(fp16)"
|
||||
// tuple_name<ck::Tuple<ck::half_t,float,double>>() returns "Tuple(fp16,fp32,fp64)"
|
||||
template <typename T>
|
||||
requires requires { []<typename... Ts>(ck::Tuple<Ts...>*) {}(static_cast<T*>(nullptr)); }
|
||||
constexpr std::string tuple_name()
|
||||
{
|
||||
return []<typename... Ts>(ck::Tuple<Ts...>*) constexpr {
|
||||
if constexpr(sizeof...(Ts) == 0)
|
||||
{
|
||||
return std::string("EmptyTuple");
|
||||
}
|
||||
else if constexpr((IsLayoutType<Ts> && ...))
|
||||
{
|
||||
// Lambda wrapper for layout_name
|
||||
auto layout_name_fn = []<typename U>() { return layout_name<U>(); };
|
||||
return detail::build_list_string<decltype(layout_name_fn), Ts...>("Tuple",
|
||||
layout_name_fn);
|
||||
}
|
||||
else if constexpr((IsDataType<Ts> && ...))
|
||||
{
|
||||
// Lambda wrapper for type_name
|
||||
auto type_name_fn = []<typename U>() { return type_name<U>(); };
|
||||
return detail::build_list_string<decltype(type_name_fn), Ts...>("Tuple", type_name_fn);
|
||||
}
|
||||
else
|
||||
{
|
||||
static_assert((IsLayoutType<Ts> && ...) || (IsDataType<Ts> && ...),
|
||||
"Tuple elements must be all layouts or all data types, not mixed");
|
||||
return std::string{}; // unreachable
|
||||
}
|
||||
}(static_cast<T*>(nullptr));
|
||||
}
|
||||
|
||||
} // namespace ck_tile::reflect::detail
|
||||
|
||||
@@ -16,20 +16,13 @@ function(add_ck_builder_test test_name)
|
||||
target_link_libraries(${test_name} PRIVATE GTest::gtest_main GTest::gmock)
|
||||
endfunction()
|
||||
|
||||
# The test_conv_builder target has all the unit tests (each test should run < 10 ms)
|
||||
add_ck_builder_test(test_conv_builder
|
||||
test_conv_builder.cpp
|
||||
test_instance_traits.cpp
|
||||
test_instance_traits_util.cpp
|
||||
testing_utils.cpp)
|
||||
|
||||
# Testing the virtual GetInstanceString methods requires kernel compilation.
|
||||
add_ck_builder_test(test_get_instance_string
|
||||
test_get_instance_string.cpp)
|
||||
|
||||
add_ck_builder_test(test_inline_diff test_inline_diff.cpp testing_utils.cpp)
|
||||
|
||||
add_ck_builder_test(test_ckb_build_fwd_instances
|
||||
conv/test_ckb_conv_fwd_2d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_2d_fp32.cpp
|
||||
conv/test_ckb_conv_fwd_3d_bf16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_fp16.cpp
|
||||
conv/test_ckb_conv_fwd_3d_fp32.cpp)
|
||||
263
experimental/builder/test/test_instance_traits_util.cpp
Normal file
263
experimental/builder/test/test_instance_traits_util.cpp
Normal file
@@ -0,0 +1,263 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
#include <ck_tile/builder/reflect/instance_traits_util.hpp>
|
||||
#include <ck/utility/data_type.hpp>
|
||||
#include <ck/utility/sequence.hpp>
|
||||
#include <ck/utility/blkgemmpipe_scheduler.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
|
||||
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
|
||||
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
|
||||
|
||||
namespace ck_tile::reflect::detail {
|
||||
namespace {
|
||||
|
||||
using ::testing::ElementsAre;
|
||||
using ::testing::IsEmpty;
|
||||
|
||||
TEST(InstanceTraitsUtil, SequenceToArrayReturnsEmptyArrayForEmptySequence)
|
||||
{
|
||||
EXPECT_THAT(SequenceToArray<ck::Sequence<>>::value, IsEmpty());
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, SequenceToArrayReturnsArrayWithSingleElement)
|
||||
{
|
||||
EXPECT_THAT(SequenceToArray<ck::Sequence<42>>::value, ElementsAre(42));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, SequenceToArrayReturnsArrayWithMultipleElements)
|
||||
{
|
||||
EXPECT_THAT((SequenceToArray<ck::Sequence<1, 2, 3, 4, 5>>::value), ElementsAre(1, 2, 3, 4, 5));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, TypeNameReturnsCorrectStrings)
|
||||
{
|
||||
EXPECT_THAT((std::vector<std::string_view>{type_name<ck::half_t>(),
|
||||
type_name<float>(),
|
||||
type_name<double>(),
|
||||
type_name<int8_t>(),
|
||||
type_name<int32_t>(),
|
||||
type_name<ck::bhalf_t>(),
|
||||
type_name<ck::f8_t>(),
|
||||
type_name<ck::bf8_t>()}),
|
||||
ElementsAre("fp16", "fp32", "fp64", "s8", "s32", "bf16", "fp8", "bf8"));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, LayoutNameReturnsCorrectStringsForGemmLayouts)
|
||||
{
|
||||
namespace gemm = ck::tensor_layout::gemm;
|
||||
EXPECT_THAT((std::vector<std::string_view>{layout_name<gemm::RowMajor>(),
|
||||
layout_name<gemm::ColumnMajor>(),
|
||||
layout_name<gemm::MFMA>()}),
|
||||
ElementsAre("RowMajor", "ColumnMajor", "MFMA"));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, LayoutNameReturnsCorrectStringsForConvLayouts)
|
||||
{
|
||||
namespace conv = ck::tensor_layout::convolution;
|
||||
EXPECT_THAT((std::vector<std::string_view>{
|
||||
// Input tensor layouts
|
||||
// TODO(deprecated): Remove non-grouped layouts once instances are removed.
|
||||
layout_name<conv::NCHW>(),
|
||||
layout_name<conv::NHWC>(),
|
||||
layout_name<conv::NCDHW>(),
|
||||
layout_name<conv::NDHWC>(),
|
||||
// Grouped input layouts
|
||||
layout_name<conv::GNCHW>(),
|
||||
layout_name<conv::GNHWC>(),
|
||||
// Weight tensor layouts
|
||||
layout_name<conv::KCYX>(),
|
||||
layout_name<conv::KYXC>(),
|
||||
layout_name<conv::GKCYX>(),
|
||||
layout_name<conv::GKYXC>(),
|
||||
// Output tensor layouts
|
||||
layout_name<conv::NKHW>(),
|
||||
layout_name<conv::NHWK>(),
|
||||
layout_name<conv::GNKHW>(),
|
||||
layout_name<conv::GNHWK>(),
|
||||
// Strided layouts
|
||||
// TODO(deprecated): Remove strided layouts once instances are removed.
|
||||
layout_name<conv::G_NHW_C>(),
|
||||
layout_name<conv::G_K_YX_C>(),
|
||||
layout_name<conv::G_NHW_K>(),
|
||||
// Bias layouts
|
||||
layout_name<conv::G_C>(),
|
||||
layout_name<conv::G_K>()}),
|
||||
ElementsAre("NCHW",
|
||||
"NHWC",
|
||||
"NCDHW",
|
||||
"NDHWC",
|
||||
"GNCHW",
|
||||
"GNHWC",
|
||||
"KCYX",
|
||||
"KYXC",
|
||||
"GKCYX",
|
||||
"GKYXC",
|
||||
"NKHW",
|
||||
"NHWK",
|
||||
"GNKHW",
|
||||
"GNHWK",
|
||||
"G_NHW_C",
|
||||
"G_K_YX_C",
|
||||
"G_NHW_K",
|
||||
"G_C",
|
||||
"G_K"));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, ElementwiseOpNameReturnsCorrectStrings)
|
||||
{
|
||||
namespace element_wise = ck::tensor_operation::element_wise;
|
||||
EXPECT_THAT((std::vector<std::string_view>{
|
||||
elementwise_op_name<element_wise::PassThrough>(),
|
||||
elementwise_op_name<element_wise::Scale>(),
|
||||
elementwise_op_name<element_wise::Bilinear>(),
|
||||
elementwise_op_name<element_wise::Add>(),
|
||||
elementwise_op_name<element_wise::AddRelu>(),
|
||||
elementwise_op_name<element_wise::Relu>(),
|
||||
elementwise_op_name<element_wise::BiasNormalizeInInferClamp>(),
|
||||
elementwise_op_name<element_wise::Clamp>(),
|
||||
elementwise_op_name<element_wise::AddClamp>()}),
|
||||
ElementsAre("PassThrough",
|
||||
"Scale",
|
||||
"Bilinear",
|
||||
"Add",
|
||||
"AddRelu",
|
||||
"Relu",
|
||||
"BiasNormalizeInInferClamp",
|
||||
"Clamp",
|
||||
"AddClamp"));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, ConvFwdSpecNameReturnsCorrectStrings)
|
||||
{
|
||||
using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
|
||||
EXPECT_THAT(
|
||||
(std::vector<std::string_view>{conv_fwd_spec_name(Default),
|
||||
conv_fwd_spec_name(Filter1x1Stride1Pad0),
|
||||
conv_fwd_spec_name(Filter1x1Pad0),
|
||||
conv_fwd_spec_name(Filter3x3),
|
||||
conv_fwd_spec_name(OddC)}),
|
||||
ElementsAre("Default", "Filter1x1Stride1Pad0", "Filter1x1Pad0", "Filter3x3", "OddC"));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, GemmSpecNameReturnsCorrectStrings)
|
||||
{
|
||||
using enum ck::tensor_operation::device::GemmSpecialization;
|
||||
EXPECT_THAT((std::vector<std::string_view>{gemm_spec_name(Default),
|
||||
gemm_spec_name(MPadding),
|
||||
gemm_spec_name(NPadding),
|
||||
gemm_spec_name(KPadding),
|
||||
gemm_spec_name(MNPadding),
|
||||
gemm_spec_name(MKPadding),
|
||||
gemm_spec_name(NKPadding),
|
||||
gemm_spec_name(MNKPadding),
|
||||
gemm_spec_name(OPadding),
|
||||
gemm_spec_name(MOPadding),
|
||||
gemm_spec_name(NOPadding),
|
||||
gemm_spec_name(KOPadding),
|
||||
gemm_spec_name(MNOPadding),
|
||||
gemm_spec_name(MKOPadding),
|
||||
gemm_spec_name(NKOPadding),
|
||||
gemm_spec_name(MNKOPadding)}),
|
||||
ElementsAre("Default",
|
||||
"MPadding",
|
||||
"NPadding",
|
||||
"KPadding",
|
||||
"MNPadding",
|
||||
"MKPadding",
|
||||
"NKPadding",
|
||||
"MNKPadding",
|
||||
"OPadding",
|
||||
"MOPadding",
|
||||
"NOPadding",
|
||||
"KOPadding",
|
||||
"MNOPadding",
|
||||
"MKOPadding",
|
||||
"NKOPadding",
|
||||
"MNKOPadding"));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, PipelineSchedulerNameReturnsCorrectStrings)
|
||||
{
|
||||
using enum ck::BlockGemmPipelineScheduler;
|
||||
EXPECT_THAT((std::vector<std::string_view>{pipeline_scheduler_name(Intrawave),
|
||||
pipeline_scheduler_name(Interwave)}),
|
||||
ElementsAre("Intrawave", "Interwave"));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, PipelineVersionNameReturnsCorrectStrings)
|
||||
{
|
||||
using enum ck::BlockGemmPipelineVersion;
|
||||
EXPECT_THAT((std::vector<std::string_view>{pipeline_version_name(v1),
|
||||
pipeline_version_name(v2),
|
||||
pipeline_version_name(v3),
|
||||
pipeline_version_name(v4),
|
||||
pipeline_version_name(v5)}),
|
||||
ElementsAre("v1", "v2", "v3", "v4", "v5"));
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, TupleNameReturnsEmptyTupleForEmptyTuple)
|
||||
{
|
||||
EXPECT_EQ(tuple_name<ck::Tuple<>>(), "EmptyTuple");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForSingleLayout)
|
||||
{
|
||||
EXPECT_EQ(tuple_name<ck::Tuple<ck::tensor_layout::convolution::NCHW>>(), "Tuple(NCHW)");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForTwoLayouts)
|
||||
{
|
||||
EXPECT_EQ((tuple_name<ck::Tuple<ck::tensor_layout::convolution::NCHW,
|
||||
ck::tensor_layout::convolution::NHWC>>()),
|
||||
"Tuple(NCHW,NHWC)");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForThreeLayouts)
|
||||
{
|
||||
EXPECT_EQ((tuple_name<ck::Tuple<ck::tensor_layout::convolution::NCHW,
|
||||
ck::tensor_layout::convolution::NHWC,
|
||||
ck::tensor_layout::convolution::NKHW>>()),
|
||||
"Tuple(NCHW,NHWC,NKHW)");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForSingleDataType)
|
||||
{
|
||||
EXPECT_EQ(tuple_name<ck::Tuple<ck::half_t>>(), "Tuple(fp16)");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForTwoDataTypes)
|
||||
{
|
||||
EXPECT_EQ((tuple_name<ck::Tuple<ck::half_t, float>>()), "Tuple(fp16,fp32)");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForThreeDataTypes)
|
||||
{
|
||||
EXPECT_EQ((tuple_name<ck::Tuple<ck::half_t, float, double>>()), "Tuple(fp16,fp32,fp64)");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForEmptySequence)
|
||||
{
|
||||
EXPECT_EQ(sequence_name<ck::Sequence<>>(), "Seq()");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForSingleValueSequence)
|
||||
{
|
||||
EXPECT_EQ(sequence_name<ck::Sequence<42>>(), "Seq(42)");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForTwoValueSequence)
|
||||
{
|
||||
EXPECT_EQ((sequence_name<ck::Sequence<1, 2>>()), "Seq(1,2)");
|
||||
}
|
||||
|
||||
TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForMultipleValueSequence)
|
||||
{
|
||||
EXPECT_EQ((sequence_name<ck::Sequence<256, 128, 64, 32, 16>>()), "Seq(256,128,64,32,16)");
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace ck_tile::reflect::detail
|
||||
@@ -12,6 +12,8 @@ namespace element_wise {
|
||||
|
||||
struct Add
|
||||
{
|
||||
static constexpr const char* name = "Add";
|
||||
|
||||
template <typename Y, typename X0, typename X1>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
|
||||
|
||||
@@ -94,6 +96,8 @@ struct Add
|
||||
|
||||
struct Max
|
||||
{
|
||||
static constexpr const char* name = "Max";
|
||||
|
||||
template <typename Y, typename X0, typename X1>
|
||||
__host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
|
||||
{
|
||||
@@ -105,6 +109,8 @@ struct Max
|
||||
|
||||
struct Min
|
||||
{
|
||||
static constexpr const char* name = "Min";
|
||||
|
||||
template <typename Y, typename X0, typename X1>
|
||||
__host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
|
||||
{
|
||||
@@ -116,6 +122,8 @@ struct Min
|
||||
|
||||
struct Multiply
|
||||
{
|
||||
static constexpr const char* name = "Multiply";
|
||||
|
||||
template <typename Y, typename X0, typename X1>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
|
||||
|
||||
@@ -208,6 +216,8 @@ struct Multiply
|
||||
|
||||
struct ScaleAdd
|
||||
{
|
||||
static constexpr const char* name = "ScaleAdd";
|
||||
|
||||
__host__ __device__ ScaleAdd(float scale = 1.f) : scale_(scale) {}
|
||||
|
||||
template <typename Y, typename X0, typename X1>
|
||||
@@ -235,6 +245,8 @@ struct ScaleAdd
|
||||
|
||||
struct Subtract
|
||||
{
|
||||
static constexpr const char* name = "Subtract";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
|
||||
|
||||
@@ -279,6 +291,8 @@ struct Subtract
|
||||
|
||||
struct Bilinear
|
||||
{
|
||||
static constexpr const char* name = "Bilinear";
|
||||
|
||||
Bilinear(float alpha = 1.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
|
||||
|
||||
template <typename Y, typename X0, typename X1>
|
||||
@@ -353,6 +367,8 @@ struct Bilinear
|
||||
|
||||
struct AddClamp
|
||||
{
|
||||
static constexpr const char* name = "AddClamp";
|
||||
|
||||
AddClamp(float floor = 0.f, float ceil = NumericLimits<float>::Max())
|
||||
: floor_(floor), ceil_(ceil){};
|
||||
|
||||
@@ -442,6 +458,8 @@ struct AddClamp
|
||||
|
||||
struct AddRelu
|
||||
{
|
||||
static constexpr const char* name = "AddRelu";
|
||||
|
||||
template <typename Y, typename X0, typename X1>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
|
||||
|
||||
@@ -523,6 +541,8 @@ struct AddRelu
|
||||
|
||||
struct AddHardswish
|
||||
{
|
||||
static constexpr const char* name = "AddHardswish";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
|
||||
|
||||
@@ -560,6 +580,8 @@ struct AddHardswish
|
||||
// E = FastGelu(C + D)
|
||||
struct AddFastGelu
|
||||
{
|
||||
static constexpr const char* name = "AddFastGelu";
|
||||
|
||||
template <typename E, typename C, typename D>
|
||||
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
|
||||
|
||||
@@ -625,6 +647,8 @@ struct AddFastGelu
|
||||
// E = MultiplyFastGelu(C + D)
|
||||
struct MultiplyFastGelu
|
||||
{
|
||||
static constexpr const char* name = "MultiplyFastGelu";
|
||||
|
||||
template <typename E, typename C, typename D>
|
||||
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
|
||||
|
||||
@@ -690,6 +714,8 @@ struct MultiplyFastGelu
|
||||
// E = Silu(C + D)
|
||||
struct AddSilu
|
||||
{
|
||||
static constexpr const char* name = "AddSilu";
|
||||
|
||||
template <typename E, typename C, typename D>
|
||||
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
|
||||
|
||||
@@ -740,6 +766,8 @@ struct AddSilu
|
||||
|
||||
struct ConvScaleAdd
|
||||
{
|
||||
static constexpr const char* name = "ConvScaleAdd";
|
||||
|
||||
__host__ __device__ ConvScaleAdd(float scale_in = 1.f,
|
||||
float scale_wei = 1.f,
|
||||
float scale_out = 1.f)
|
||||
|
||||
@@ -13,6 +13,8 @@ namespace element_wise {
|
||||
template <typename... UnaryOpsSet>
|
||||
struct UnaryCombinedOp
|
||||
{
|
||||
static constexpr const char* name = "UnaryCombinedOp";
|
||||
|
||||
__host__ __device__ UnaryCombinedOp() : unary_ops_() {}
|
||||
|
||||
__host__ __device__ UnaryCombinedOp(UnaryOpsSet... unary_ops) : unary_ops_(unary_ops...) {}
|
||||
@@ -33,6 +35,8 @@ struct UnaryCombinedOp
|
||||
template <typename BinaryOp, typename UnaryOp0, typename UnaryOp1>
|
||||
struct BinaryWithUnaryCombinedOp
|
||||
{
|
||||
static constexpr const char* name = "BinaryWithUnaryCombinedOp";
|
||||
|
||||
__host__ __device__ BinaryWithUnaryCombinedOp() : binary_op_(), unary_op0_(), unary_op1_() {}
|
||||
|
||||
__host__ __device__ BinaryWithUnaryCombinedOp(BinaryOp binary_op,
|
||||
@@ -66,6 +70,8 @@ template <typename BinaryOp0,
|
||||
typename UnaryOp2>
|
||||
struct TrinaryWithUnaryCombinedOp
|
||||
{
|
||||
static constexpr const char* name = "TrinaryWithUnaryCombinedOp";
|
||||
|
||||
__host__ __device__ TrinaryWithUnaryCombinedOp()
|
||||
: binary_op0_(), binary_op1_(), unary_op0_(), unary_op1_(), unary_op2_()
|
||||
{
|
||||
|
||||
@@ -33,6 +33,8 @@ namespace element_wise {
|
||||
|
||||
struct AddReluAdd
|
||||
{
|
||||
static constexpr const char* name = "AddReluAdd";
|
||||
|
||||
template <typename Y, typename X0, typename X1, typename X2>
|
||||
__host__ __device__ constexpr void operator()(Y&, const X0&, const X1&, const X2&) const;
|
||||
|
||||
@@ -102,6 +104,8 @@ struct AddReluAdd
|
||||
|
||||
struct AddHardswishAdd
|
||||
{
|
||||
static constexpr const char* name = "AddHardswishAdd";
|
||||
|
||||
template <typename Y, typename X0, typename X1, typename X2>
|
||||
__host__ __device__ constexpr void operator()(Y&, const X0&, const X1&, const X2&) const;
|
||||
|
||||
@@ -134,6 +138,8 @@ struct AddHardswishAdd
|
||||
// E = C + D0 + D1
|
||||
struct AddAdd
|
||||
{
|
||||
static constexpr const char* name = "AddAdd";
|
||||
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
__host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const
|
||||
{
|
||||
@@ -163,6 +169,8 @@ struct AddAdd
|
||||
// E = (C + D0) x D1
|
||||
struct AddMultiply
|
||||
{
|
||||
static constexpr const char* name = "AddMultiply";
|
||||
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
__host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
|
||||
|
||||
@@ -199,6 +207,8 @@ struct AddMultiply
|
||||
// E = C x D0 + D1
|
||||
struct MultiplyAdd
|
||||
{
|
||||
static constexpr const char* name = "MultiplyAdd";
|
||||
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
__host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
|
||||
|
||||
@@ -251,6 +261,8 @@ struct MultiplyAdd
|
||||
|
||||
struct MultiplyMultiply
|
||||
{
|
||||
static constexpr const char* name = "MultiplyMultiply";
|
||||
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
|
||||
@@ -306,6 +318,8 @@ struct MultiplyMultiply
|
||||
|
||||
struct MultiplyAddFastGelu
|
||||
{
|
||||
static constexpr const char* name = "MultiplyAddFastGelu";
|
||||
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
|
||||
@@ -327,6 +341,8 @@ struct MultiplyAddFastGelu
|
||||
// E = FastGelu(C + D0 + D1)
|
||||
struct AddAddFastGelu
|
||||
{
|
||||
static constexpr const char* name = "AddAddFastGelu";
|
||||
|
||||
template <typename E, typename C, typename D0, typename D1>
|
||||
__host__ __device__ constexpr void
|
||||
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
|
||||
@@ -398,6 +414,7 @@ struct AddAddFastGelu
|
||||
// E = Relu(alpha1 * C + alpha2 * D0 + D1)
|
||||
struct ScaleAddScaleAddRelu
|
||||
{
|
||||
static constexpr const char* name = "ScaleAddScaleAddRelu";
|
||||
|
||||
ScaleAddScaleAddRelu(const float alpha1 = 1.f, const float alpha2 = 1.f)
|
||||
: alpha1_(alpha1), alpha2_(alpha2)
|
||||
@@ -462,6 +479,8 @@ struct ScaleAddScaleAddRelu
|
||||
|
||||
struct Normalize
|
||||
{
|
||||
static constexpr const char* name = "Normalize";
|
||||
|
||||
// FIXME: is double absolutely necessary?
|
||||
Normalize(double epsilon = 1e-4) : epsilon_(epsilon) {}
|
||||
|
||||
@@ -533,6 +552,8 @@ struct Normalize
|
||||
// The data type of mean and variance is used as AccDataType
|
||||
struct NormalizeInInfer
|
||||
{
|
||||
static constexpr const char* name = "NormalizeInInfer";
|
||||
|
||||
NormalizeInInfer(double epsilon = 1e-4) : epsilon_(epsilon) {}
|
||||
|
||||
template <typename T1, typename T2, typename T3, typename T4>
|
||||
@@ -565,6 +586,8 @@ struct NormalizeInInfer
|
||||
// used by Conv+Bias+BatchNorm+Clamp inference
|
||||
struct BiasNormalizeInInferClamp
|
||||
{
|
||||
static constexpr const char* name = "BiasNormalizeInInferClamp";
|
||||
|
||||
BiasNormalizeInInferClamp(float floor = 0.f,
|
||||
float ceil = NumericLimits<float>::Max(),
|
||||
float epsilon = 1e-4)
|
||||
@@ -620,6 +643,8 @@ struct UnaryTypeConvert;
|
||||
template <>
|
||||
struct UnaryTypeConvert<float, ck::bhalf_t>
|
||||
{
|
||||
static constexpr const char* name = "UnaryTypeConvert";
|
||||
|
||||
__host__ __device__ void operator()(float& y, ck::bhalf_t& x) const
|
||||
{
|
||||
y = ck::type_convert<float, ck::bhalf_t>(x);
|
||||
@@ -629,6 +654,8 @@ struct UnaryTypeConvert<float, ck::bhalf_t>
|
||||
template <>
|
||||
struct UnaryTypeConvert<ck::bhalf_t, float>
|
||||
{
|
||||
static constexpr const char* name = "UnaryTypeConvert";
|
||||
|
||||
__host__ __device__ void operator()(ck::bhalf_t& y, float& x) const
|
||||
{
|
||||
y = ck::type_convert<ck::bhalf_t, float>(x);
|
||||
|
||||
@@ -24,6 +24,8 @@ namespace element_wise {
|
||||
template <typename Activation>
|
||||
struct Activation_Mul_Clamp
|
||||
{
|
||||
static constexpr const char* name = "Activation_Mul_Clamp";
|
||||
|
||||
// Convolution + Activation (piecewise linear function)
|
||||
// If an activation is piecewise linear function, then Activation(Sy * Qy) = Sy * Activation(Qy)
|
||||
// Z = Activation(Y) = Activation(W @ X)
|
||||
@@ -71,6 +73,8 @@ struct Activation_Mul_Clamp
|
||||
template <typename Activation>
|
||||
struct Mul_Activation_Mul_Clamp
|
||||
{
|
||||
static constexpr const char* name = "Mul_Activation_Mul_Clamp";
|
||||
|
||||
// Convolution + Activation (non piecewise linear function)
|
||||
// Z = Activation(Y) = Activation(W @ X)
|
||||
// Sz * Qz = Activation(Sy * Qy)
|
||||
@@ -101,6 +105,8 @@ struct Mul_Activation_Mul_Clamp
|
||||
template <typename Activation>
|
||||
struct Activation_Mul2_Clamp
|
||||
{
|
||||
static constexpr const char* name = "Activation_Mul2_Clamp";
|
||||
|
||||
Activation_Mul2_Clamp(Activation activationOp) : activationOp_(activationOp) {}
|
||||
|
||||
__host__ __device__ constexpr void
|
||||
@@ -131,6 +137,8 @@ struct Activation_Mul2_Clamp
|
||||
template <typename Activation>
|
||||
struct Add_Activation_Mul_Clamp
|
||||
{
|
||||
static constexpr const char* name = "Add_Activation_Mul_Clamp";
|
||||
|
||||
// Convolution + bias
|
||||
// Let Bias = B = Sw * Sx * Qb
|
||||
// Where Qb is int32
|
||||
@@ -175,6 +183,8 @@ struct Add_Activation_Mul_Clamp
|
||||
template <typename Activation>
|
||||
struct Add_Activation_Mul2_Clamp
|
||||
{
|
||||
static constexpr const char* name = "Add_Activation_Mul2_Clamp";
|
||||
|
||||
Add_Activation_Mul2_Clamp(Activation activationOp) : activationOp_(activationOp) {}
|
||||
|
||||
__host__ __device__ constexpr void
|
||||
@@ -206,6 +216,8 @@ struct Add_Activation_Mul2_Clamp
|
||||
template <typename Activation>
|
||||
struct Add_Mul_Activation_Mul_Clamp
|
||||
{
|
||||
static constexpr const char* name = "Add_Mul_Activation_Mul_Clamp";
|
||||
|
||||
// Convolution + Activation (non piecewise linear function)
|
||||
// Z = Activation(Y) = Activation(W @ X + B)
|
||||
// Sz * Qz = Activation(Sy * Qy)
|
||||
@@ -250,6 +262,8 @@ struct Add_Mul_Activation_Mul_Clamp
|
||||
template <typename Activation>
|
||||
struct Add_Mul2_Activation_Mul_Clamp
|
||||
{
|
||||
static constexpr const char* name = "Add_Mul2_Activation_Mul_Clamp";
|
||||
|
||||
Add_Mul2_Activation_Mul_Clamp(float scale_z_inv, Activation activationOp)
|
||||
: scale_z_inv_(scale_z_inv), activationOp_(activationOp)
|
||||
{
|
||||
|
||||
@@ -157,6 +157,8 @@ namespace element_wise {
|
||||
|
||||
struct PassThroughPack8
|
||||
{
|
||||
static constexpr const char* name = "PassThroughPack8";
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const;
|
||||
|
||||
@@ -265,6 +267,8 @@ struct PassThroughPack8
|
||||
|
||||
struct DequantPack8
|
||||
{
|
||||
static constexpr const char* name = "DequantPack8";
|
||||
|
||||
template <typename Y, typename X, typename Z>
|
||||
__host__ __device__ void operator()(Y& y, const X& x, const Z& z) const;
|
||||
|
||||
@@ -301,6 +305,8 @@ struct DequantPack8
|
||||
|
||||
struct PassThroughPack2
|
||||
{
|
||||
static constexpr const char* name = "PassThroughPack2";
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const;
|
||||
|
||||
@@ -332,6 +338,8 @@ struct PassThroughPack2
|
||||
|
||||
struct PassThrough
|
||||
{
|
||||
static constexpr const char* name = "PassThrough";
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const;
|
||||
|
||||
@@ -552,12 +560,12 @@ struct PassThrough
|
||||
{
|
||||
y = type_convert<bf8_t>(x);
|
||||
}
|
||||
|
||||
static constexpr const char* name = "PassThrough";
|
||||
};
|
||||
|
||||
struct UnaryConvert
|
||||
{
|
||||
static constexpr const char* name = "UnaryConvert";
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const
|
||||
{
|
||||
@@ -567,6 +575,8 @@ struct UnaryConvert
|
||||
|
||||
struct ConvertBF16RTN
|
||||
{
|
||||
static constexpr const char* name = "ConvertBF16RTN";
|
||||
|
||||
// convert to bf16 using round to nearest (rtn)
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const
|
||||
@@ -584,6 +594,8 @@ struct ConvertBF16RTN
|
||||
|
||||
struct ConvertF8SR
|
||||
{
|
||||
static constexpr const char* name = "ConvertF8SR";
|
||||
|
||||
// convert to fp8 using stochastic rounding (SR)
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const
|
||||
@@ -602,6 +614,8 @@ struct ConvertF8SR
|
||||
|
||||
struct ConvertF8RNE
|
||||
{
|
||||
static constexpr const char* name = "ConvertF8RNE";
|
||||
|
||||
// convert to fp8 using rounding to nearest even
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const
|
||||
@@ -620,6 +634,8 @@ struct ConvertF8RNE
|
||||
|
||||
struct Scale
|
||||
{
|
||||
static constexpr const char* name = "Scale";
|
||||
|
||||
__host__ __device__ Scale(float scale = 1.f) : scale_(scale) {}
|
||||
|
||||
template <typename Y, typename X>
|
||||
@@ -665,6 +681,8 @@ struct Scale
|
||||
|
||||
struct ScaleAndResetNaNToMinusInfinity
|
||||
{
|
||||
static constexpr const char* name = "ScaleAndResetNaNToMinusInfinity";
|
||||
|
||||
__host__ __device__ ScaleAndResetNaNToMinusInfinity(float scale) : scale_(scale) {}
|
||||
|
||||
template <typename Y, typename X>
|
||||
@@ -681,6 +699,8 @@ struct ScaleAndResetNaNToMinusInfinity
|
||||
|
||||
struct UnaryDivide
|
||||
{
|
||||
static constexpr const char* name = "UnaryDivide";
|
||||
|
||||
__host__ __device__ UnaryDivide(const int32_t divider = 1) : divider_(divider) {}
|
||||
|
||||
template <typename T>
|
||||
@@ -725,6 +745,8 @@ struct UnaryDivide
|
||||
|
||||
struct UnarySquare
|
||||
{
|
||||
static constexpr const char* name = "UnarySquare";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -741,6 +763,8 @@ struct UnarySquare
|
||||
|
||||
struct UnaryAbs
|
||||
{
|
||||
static constexpr const char* name = "UnaryAbs";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -771,6 +795,8 @@ struct UnaryAbs
|
||||
|
||||
struct UnarySqrt
|
||||
{
|
||||
static constexpr const char* name = "UnarySqrt";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -783,6 +809,8 @@ struct UnarySqrt
|
||||
|
||||
struct Clamp
|
||||
{
|
||||
static constexpr const char* name = "Clamp";
|
||||
|
||||
Clamp(float floor = 0.f, float ceil = NumericLimits<float>::Max())
|
||||
: floor_(floor), ceil_(ceil){};
|
||||
|
||||
@@ -856,6 +884,8 @@ struct Clamp
|
||||
|
||||
struct Relu
|
||||
{
|
||||
static constexpr const char* name = "Relu";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -892,6 +922,8 @@ struct Relu
|
||||
// gpu code use lower accuracy "_ocml_exp_f32" and "rcp" function
|
||||
struct FastGelu
|
||||
{
|
||||
static constexpr const char* name = "FastGelu";
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ void operator()(Y& y, const X& x) const;
|
||||
|
||||
@@ -1007,6 +1039,8 @@ struct FastGelu
|
||||
// y = 0.5*x*(1+erf(x/sqrt(2)))
|
||||
struct Gelu
|
||||
{
|
||||
static constexpr const char* name = "Gelu";
|
||||
|
||||
template <typename Y, typename X>
|
||||
__host__ __device__ void operator()(Y& y, const X& x) const;
|
||||
|
||||
@@ -1025,6 +1059,8 @@ struct Gelu
|
||||
|
||||
struct Sigmoid
|
||||
{
|
||||
static constexpr const char* name = "Sigmoid";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1049,6 +1085,8 @@ struct Sigmoid
|
||||
|
||||
struct Silu
|
||||
{
|
||||
static constexpr const char* name = "SiLU";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1062,6 +1100,8 @@ struct Silu
|
||||
|
||||
struct TanH
|
||||
{
|
||||
static constexpr const char* name = "TanH";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1085,6 +1125,8 @@ struct TanH
|
||||
|
||||
struct ACos
|
||||
{
|
||||
static constexpr const char* name = "ACos";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1099,6 +1141,8 @@ struct ACos
|
||||
|
||||
struct Neg
|
||||
{
|
||||
static constexpr const char* name = "Neg";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1113,6 +1157,8 @@ struct Neg
|
||||
|
||||
struct ATan
|
||||
{
|
||||
static constexpr const char* name = "ATan";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1127,6 +1173,8 @@ struct ATan
|
||||
|
||||
struct Sin
|
||||
{
|
||||
static constexpr const char* name = "Sin";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1141,6 +1189,8 @@ struct Sin
|
||||
|
||||
struct ASinH
|
||||
{
|
||||
static constexpr const char* name = "ASinH";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1155,6 +1205,8 @@ struct ASinH
|
||||
|
||||
struct Cos
|
||||
{
|
||||
static constexpr const char* name = "Cos";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1169,6 +1221,8 @@ struct Cos
|
||||
|
||||
struct ACosH
|
||||
{
|
||||
static constexpr const char* name = "ACosH";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1183,6 +1237,8 @@ struct ACosH
|
||||
|
||||
struct Tan
|
||||
{
|
||||
static constexpr const char* name = "Tan";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1197,6 +1253,8 @@ struct Tan
|
||||
|
||||
struct ATanH
|
||||
{
|
||||
static constexpr const char* name = "ATanH";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1211,6 +1269,8 @@ struct ATanH
|
||||
|
||||
struct SinH
|
||||
{
|
||||
static constexpr const char* name = "SinH";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1225,6 +1285,8 @@ struct SinH
|
||||
|
||||
struct Ceil
|
||||
{
|
||||
static constexpr const char* name = "Ceil";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1239,6 +1301,8 @@ struct Ceil
|
||||
|
||||
struct Exp
|
||||
{
|
||||
static constexpr const char* name = "Exp";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1253,6 +1317,8 @@ struct Exp
|
||||
|
||||
struct CosH
|
||||
{
|
||||
static constexpr const char* name = "CosH";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1267,6 +1333,8 @@ struct CosH
|
||||
|
||||
struct Floor
|
||||
{
|
||||
static constexpr const char* name = "Floor";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1281,6 +1349,8 @@ struct Floor
|
||||
|
||||
struct Log
|
||||
{
|
||||
static constexpr const char* name = "Log";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1295,6 +1365,8 @@ struct Log
|
||||
|
||||
struct ASin
|
||||
{
|
||||
static constexpr const char* name = "ASin";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1309,6 +1381,8 @@ struct ASin
|
||||
|
||||
struct Rcp
|
||||
{
|
||||
static constexpr const char* name = "Rcp";
|
||||
|
||||
template <typename T>
|
||||
__host__ __device__ void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1323,6 +1397,8 @@ struct Rcp
|
||||
|
||||
struct Swish
|
||||
{
|
||||
static constexpr const char* name = "Swish";
|
||||
|
||||
Swish(float beta = 1.0f) : beta_(beta) {}
|
||||
|
||||
template <typename Y, typename X>
|
||||
@@ -1352,6 +1428,8 @@ struct Swish
|
||||
|
||||
struct SoftRelu
|
||||
{
|
||||
static constexpr const char* name = "SoftRelu";
|
||||
|
||||
SoftRelu(float alpha = 1.f) : alpha_(alpha){};
|
||||
|
||||
template <typename T>
|
||||
@@ -1380,6 +1458,8 @@ struct SoftRelu
|
||||
|
||||
struct Power
|
||||
{
|
||||
static constexpr const char* name = "Power";
|
||||
|
||||
Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
|
||||
: alpha_(alpha), beta_(beta), gamma_(gamma){};
|
||||
|
||||
@@ -1414,6 +1494,8 @@ struct Power
|
||||
|
||||
struct ClippedRelu
|
||||
{
|
||||
static constexpr const char* name = "ClippedRelu";
|
||||
|
||||
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
|
||||
|
||||
template <typename T>
|
||||
@@ -1443,6 +1525,8 @@ struct ClippedRelu
|
||||
|
||||
struct LeakyRelu
|
||||
{
|
||||
static constexpr const char* name = "LeakyRelu";
|
||||
|
||||
LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
|
||||
|
||||
template <typename T>
|
||||
@@ -1470,6 +1554,8 @@ struct LeakyRelu
|
||||
|
||||
struct Elu
|
||||
{
|
||||
static constexpr const char* name = "Elu";
|
||||
|
||||
Elu(float alpha = 1.f) : alpha_(alpha){};
|
||||
|
||||
template <typename T>
|
||||
@@ -1497,6 +1583,8 @@ struct Elu
|
||||
|
||||
struct Logistic
|
||||
{
|
||||
static constexpr const char* name = "Logistic";
|
||||
|
||||
Logistic(float alpha = 1.f) : alpha_(alpha){};
|
||||
|
||||
template <typename T>
|
||||
@@ -1525,6 +1613,8 @@ struct Logistic
|
||||
|
||||
struct ConvInvscale
|
||||
{
|
||||
static constexpr const char* name = "ConvInvscale";
|
||||
|
||||
__host__ __device__ ConvInvscale(float scale_in = 1.f,
|
||||
float scale_wei = 1.f,
|
||||
float scale_out = 1.f)
|
||||
@@ -1548,6 +1638,8 @@ struct ConvInvscale
|
||||
|
||||
struct ConvScale
|
||||
{
|
||||
static constexpr const char* name = "ConvScale";
|
||||
|
||||
__host__ __device__ ConvScale(float scale_in = 1.f,
|
||||
float scale_wei = 1.f,
|
||||
float scale_out = 1.f)
|
||||
@@ -1571,6 +1663,8 @@ struct ConvScale
|
||||
|
||||
struct ConvScaleRelu
|
||||
{
|
||||
static constexpr const char* name = "ConvScaleRelu";
|
||||
|
||||
__host__ __device__ ConvScaleRelu(float scale_in = 1.f,
|
||||
float scale_wei = 1.f,
|
||||
float scale_out = 1.f)
|
||||
|
||||
Reference in New Issue
Block a user