Merge commit '1c17bae816edc44c32ee9d1a19d79d768fd1be13' into develop

This commit is contained in:
assistant-librarian[bot]
2025-10-28 06:16:21 +00:00
parent 56845d02b8
commit 5ed7d04d90
8 changed files with 632 additions and 72 deletions

View File

@@ -23,19 +23,12 @@
namespace ck_tile::reflect::detail {
// Metaprogramming helper to convert ck::Sequence to constexpr std::array
template <typename Seq>
struct SequenceToArray;
template <ck::index_t... Is>
struct SequenceToArray<ck::Sequence<Is...>>
{
static constexpr std::array<int, sizeof...(Is)> value = {static_cast<int>(Is)...};
};
// Convert data types to string names
// Implementation detail for type name mapping
// This is the single source of truth for supported data types;
// it returns an empty string to indicate an unsupported type.
namespace impl {
template <typename T>
consteval std::string_view type_name()
consteval std::string_view type_name_impl()
{
if constexpr(std::is_same_v<T, ck::half_t>)
return "fp16";
@@ -56,22 +49,38 @@ consteval std::string_view type_name()
else if constexpr(std::is_same_v<T, ck::bf8_t>)
return "bf8";
else
static_assert(false, "unknown_type");
return std::string_view{}; // Return empty for unsupported types
}
} // namespace impl
// Convert data types to string names
// Fails at compile time for unsupported types
template <typename T>
consteval std::string_view type_name()
{
constexpr auto name = impl::type_name_impl<T>();
static_assert(!name.empty(), "Unsupported data type");
return name;
}
// Convert layout types to string names
// Concept that checks if a type is a valid data type
// Uses the impl directly to avoid triggering static_assert during concept evaluation
template <typename T>
concept IsDataType = !impl::type_name_impl<T>().empty();
// Concept that checks valid layout types
template <typename T>
concept IsLayoutType = (std::is_base_of_v<ck_tile::tensor_layout::BaseTensorLayout, T> ||
std::is_base_of_v<ck::tensor_layout::BaseTensorLayout, T>) &&
requires {
{ T::name } -> std::convertible_to<std::string_view>;
};
// Convert layout types to string names
template <IsLayoutType T>
constexpr std::string_view layout_name()
{
if constexpr((std::is_base_of_v<ck_tile::tensor_layout::BaseTensorLayout, T> ||
std::is_base_of_v<ck::tensor_layout::BaseTensorLayout, T>) &&
requires {
{ T::name } -> std::convertible_to<std::string_view>;
})
return T::name;
else
static_assert(false,
"Layout type must derive from BaseTensorLayout and have name attribute");
return T::name;
}
// Convert element-wise operation types to string names
@@ -90,64 +99,64 @@ constexpr std::string_view elementwise_op_name()
constexpr std::string_view
conv_fwd_spec_name(ck::tensor_operation::device::ConvolutionForwardSpecialization spec)
{
using ck::tensor_operation::device::ConvolutionForwardSpecialization;
using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
switch(spec)
{
case ConvolutionForwardSpecialization::Default: return "Default";
case ConvolutionForwardSpecialization::Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
case ConvolutionForwardSpecialization::Filter1x1Pad0: return "Filter1x1Pad0";
case ConvolutionForwardSpecialization::Filter3x3: return "Filter3x3";
case ConvolutionForwardSpecialization::OddC: return "OddC";
case Default: return "Default";
case Filter1x1Stride1Pad0: return "Filter1x1Stride1Pad0";
case Filter1x1Pad0: return "Filter1x1Pad0";
case Filter3x3: return "Filter3x3";
case OddC: return "OddC";
}
}
// Convert GemmSpecialization enum to string
constexpr std::string_view gemm_spec_name(ck::tensor_operation::device::GemmSpecialization spec)
{
using ck::tensor_operation::device::GemmSpecialization;
using enum ck::tensor_operation::device::GemmSpecialization;
switch(spec)
{
case GemmSpecialization::Default: return "Default";
case GemmSpecialization::MPadding: return "MPadding";
case GemmSpecialization::NPadding: return "NPadding";
case GemmSpecialization::KPadding: return "KPadding";
case GemmSpecialization::MNPadding: return "MNPadding";
case GemmSpecialization::MKPadding: return "MKPadding";
case GemmSpecialization::NKPadding: return "NKPadding";
case GemmSpecialization::MNKPadding: return "MNKPadding";
case GemmSpecialization::OPadding: return "OPadding";
case GemmSpecialization::MOPadding: return "MOPadding";
case GemmSpecialization::NOPadding: return "NOPadding";
case GemmSpecialization::KOPadding: return "KOPadding";
case GemmSpecialization::MNOPadding: return "MNOPadding";
case GemmSpecialization::MKOPadding: return "MKOPadding";
case GemmSpecialization::NKOPadding: return "NKOPadding";
case GemmSpecialization::MNKOPadding: return "MNKOPadding";
case Default: return "Default";
case MPadding: return "MPadding";
case NPadding: return "NPadding";
case KPadding: return "KPadding";
case MNPadding: return "MNPadding";
case MKPadding: return "MKPadding";
case NKPadding: return "NKPadding";
case MNKPadding: return "MNKPadding";
case OPadding: return "OPadding";
case MOPadding: return "MOPadding";
case NOPadding: return "NOPadding";
case KOPadding: return "KOPadding";
case MNOPadding: return "MNOPadding";
case MKOPadding: return "MKOPadding";
case NKOPadding: return "NKOPadding";
case MNKOPadding: return "MNKOPadding";
}
}
// Convert BlockGemmPipelineScheduler enum to string
constexpr std::string_view pipeline_scheduler_name(ck::BlockGemmPipelineScheduler sched)
{
using ck::BlockGemmPipelineScheduler;
using enum ck::BlockGemmPipelineScheduler;
switch(sched)
{
case BlockGemmPipelineScheduler::Intrawave: return "Intrawave";
case BlockGemmPipelineScheduler::Interwave: return "Interwave";
case Intrawave: return "Intrawave";
case Interwave: return "Interwave";
}
}
// Convert BlockGemmPipelineVersion enum to string
constexpr std::string_view pipeline_version_name(ck::BlockGemmPipelineVersion ver)
{
using ck::BlockGemmPipelineVersion;
using enum ck::BlockGemmPipelineVersion;
switch(ver)
{
case BlockGemmPipelineVersion::v1: return "v1";
case BlockGemmPipelineVersion::v2: return "v2";
case BlockGemmPipelineVersion::v3: return "v3";
case BlockGemmPipelineVersion::v4: return "v4";
case BlockGemmPipelineVersion::v5: return "v5";
case v1: return "v1";
case v2: return "v2";
case v3: return "v3";
case v4: return "v4";
case v5: return "v5";
}
}
@@ -167,12 +176,138 @@ inline std::string array_to_string(const std::array<T, N>& arr)
return oss.str();
}
// Handle ck::Tuple (empty tuple for DsLayout/DsDataType)
template <typename T>
constexpr std::string_view tuple_name()
// Metaprogramming helper to convert ck::Sequence to constexpr std::array
template <typename Seq>
struct SequenceToArray;
template <ck::index_t... Is>
struct SequenceToArray<ck::Sequence<Is...>>
{
// For now, just check if it's an empty tuple
return "EmptyTuple";
static constexpr std::array<int, sizeof...(Is)> value = {static_cast<int>(Is)...};
};
namespace detail {
// Generic helper to build list-like strings (Tuple, Seq, etc.)
//
// Example output: "Seq(1,2,3)"
//
// prefix: The list-like container name (e.g. "Tuple" or "Seq")
// converter_fn: A callable that converts each element to a string representation.
//   For types:  a template lambda, e.g. []<typename U>() { return type_name<U>(); }
//   For values: a regular lambda, e.g. [](auto value) { return std::to_string(value); }
template <typename ConverterFn, typename... Elements>
constexpr std::string build_list_string(std::string_view prefix, const ConverterFn& converter_fn)
{
if constexpr(sizeof...(Elements) == 0)
{
return std::string(prefix) + "()";
}
else
{
std::string result = std::string(prefix) + "(";
std::size_t index = 0;
// Fold over Elements, inserting a comma before every element after the first.
((result +=
(index++ > 0 ? "," : "") + std::string(converter_fn.template operator()<Elements>())),
...);
result += ")";
return result;
}
}
// Overload for value-based lists (sequences); converter_fn maps each value to a string.
template <typename ConverterFn, auto... Values>
constexpr std::string build_list_string_values(std::string_view prefix,
const ConverterFn& converter_fn)
{
if constexpr(sizeof...(Values) == 0)
{
return std::string(prefix) + "()";
}
else
{
std::string result = std::string(prefix) + "(";
std::size_t index = 0;
// Same comma-separated fold as build_list_string, but over non-type values.
((result += (index++ > 0 ? "," : "") + converter_fn(Values)), ...);
result += ")";
return result;
}
}
} // namespace detail
// Convert ck::Sequence to string representation
// Converts a ck::Sequence type to a string in the format "Seq(v1,v2,...,vn)"
// where each value is converted using std::to_string.
//
// Template parameter:
// T: Must be a ck::Sequence<...> type
//
// Constraints:
// - Sequence elements must support std::to_string (typically ck::index_t)
//
// Examples:
// sequence_name<ck::Sequence<>>() returns "Seq()"
// sequence_name<ck::Sequence<42>>() returns "Seq(42)"
// sequence_name<ck::Sequence<1,2,3>>() returns "Seq(1,2,3)"
// sequence_name<ck::Sequence<256,128,64>>() returns "Seq(256,128,64)"
template <typename T>
requires requires { []<ck::index_t... Is>(ck::Sequence<Is...>*) {}(static_cast<T*>(nullptr)); }
constexpr std::string sequence_name()
{
return []<ck::index_t... Is>(ck::Sequence<Is...>*) constexpr {
auto to_string_fn = [](auto value) { return std::to_string(value); };
return detail::build_list_string_values<decltype(to_string_fn), Is...>("Seq", to_string_fn);
}(static_cast<T*>(nullptr));
}
// Convert ck::Tuple to string representation
// Converts a ck::Tuple type to a string in the format "Tuple(e1,e2,...,en)"
// where each element is converted based on its type (layout names or data type names).
//
// Template parameter:
// T: Must be a ck::Tuple<...> type
//
// Constraints:
// - Empty tuples are supported and return "EmptyTuple"
// - All tuple elements must be homogeneous: either all layouts (IsLayoutType) or all data types
// (IsDataType)
// - Mixed layouts and data types in the same tuple will cause a compile-time error
//
// Examples:
// tuple_name<ck::Tuple<>>() returns "EmptyTuple"
// tuple_name<ck::Tuple<ck::tensor_layout::gemm::RowMajor>>() returns "Tuple(RowMajor)"
// tuple_name<ck::Tuple<NCHW,NHWC>>() returns "Tuple(NCHW,NHWC)"
// tuple_name<ck::Tuple<ck::half_t>>() returns "Tuple(fp16)"
// tuple_name<ck::Tuple<ck::half_t,float,double>>() returns "Tuple(fp16,fp32,fp64)"
template <typename T>
requires requires { []<typename... Ts>(ck::Tuple<Ts...>*) {}(static_cast<T*>(nullptr)); }
constexpr std::string tuple_name()
{
return []<typename... Ts>(ck::Tuple<Ts...>*) constexpr {
if constexpr(sizeof...(Ts) == 0)
{
return std::string("EmptyTuple");
}
else if constexpr((IsLayoutType<Ts> && ...))
{
// Lambda wrapper for layout_name
auto layout_name_fn = []<typename U>() { return layout_name<U>(); };
return detail::build_list_string<decltype(layout_name_fn), Ts...>("Tuple",
layout_name_fn);
}
else if constexpr((IsDataType<Ts> && ...))
{
// Lambda wrapper for type_name
auto type_name_fn = []<typename U>() { return type_name<U>(); };
return detail::build_list_string<decltype(type_name_fn), Ts...>("Tuple", type_name_fn);
}
else
{
static_assert((IsLayoutType<Ts> && ...) || (IsDataType<Ts> && ...),
"Tuple elements must be all layouts or all data types, not mixed");
return std::string{}; // unreachable
}
}(static_cast<T*>(nullptr));
}
} // namespace ck_tile::reflect::detail

View File

@@ -16,20 +16,13 @@ function(add_ck_builder_test test_name)
target_link_libraries(${test_name} PRIVATE GTest::gtest_main GTest::gmock)
endfunction()
# The test_conv_builder target has all the unit tests (each test should run < 10 ms)
add_ck_builder_test(test_conv_builder
test_conv_builder.cpp
test_instance_traits.cpp
test_instance_traits_util.cpp
testing_utils.cpp)
# Testing the virtual GetInstanceString methods requires kernel compilation.
add_ck_builder_test(test_get_instance_string
test_get_instance_string.cpp)
add_ck_builder_test(test_inline_diff test_inline_diff.cpp testing_utils.cpp)
add_ck_builder_test(test_ckb_build_fwd_instances
conv/test_ckb_conv_fwd_2d_bf16.cpp
conv/test_ckb_conv_fwd_2d_fp16.cpp
conv/test_ckb_conv_fwd_2d_fp32.cpp
conv/test_ckb_conv_fwd_3d_bf16.cpp
conv/test_ckb_conv_fwd_3d_fp16.cpp
conv/test_ckb_conv_fwd_3d_fp32.cpp)

View File

@@ -0,0 +1,263 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <ck_tile/builder/reflect/instance_traits_util.hpp>
#include <ck/utility/data_type.hpp>
#include <ck/utility/sequence.hpp>
#include <ck/utility/blkgemmpipe_scheduler.hpp>
#include <ck/tensor_operation/gpu/device/tensor_layout.hpp>
#include <ck/tensor_operation/gpu/element/element_wise_operation.hpp>
#include <ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp>
#include <ck/tensor_operation/gpu/device/gemm_specialization.hpp>
namespace ck_tile::reflect::detail {
namespace {
using ::testing::ElementsAre;
using ::testing::IsEmpty;
TEST(InstanceTraitsUtil, SequenceToArrayReturnsEmptyArrayForEmptySequence)
{
EXPECT_THAT(SequenceToArray<ck::Sequence<>>::value, IsEmpty());
}
TEST(InstanceTraitsUtil, SequenceToArrayReturnsArrayWithSingleElement)
{
EXPECT_THAT(SequenceToArray<ck::Sequence<42>>::value, ElementsAre(42));
}
TEST(InstanceTraitsUtil, SequenceToArrayReturnsArrayWithMultipleElements)
{
EXPECT_THAT((SequenceToArray<ck::Sequence<1, 2, 3, 4, 5>>::value), ElementsAre(1, 2, 3, 4, 5));
}
TEST(InstanceTraitsUtil, TypeNameReturnsCorrectStrings)
{
EXPECT_THAT((std::vector<std::string_view>{type_name<ck::half_t>(),
type_name<float>(),
type_name<double>(),
type_name<int8_t>(),
type_name<int32_t>(),
type_name<ck::bhalf_t>(),
type_name<ck::f8_t>(),
type_name<ck::bf8_t>()}),
ElementsAre("fp16", "fp32", "fp64", "s8", "s32", "bf16", "fp8", "bf8"));
}
TEST(InstanceTraitsUtil, LayoutNameReturnsCorrectStringsForGemmLayouts)
{
namespace gemm = ck::tensor_layout::gemm;
EXPECT_THAT((std::vector<std::string_view>{layout_name<gemm::RowMajor>(),
layout_name<gemm::ColumnMajor>(),
layout_name<gemm::MFMA>()}),
ElementsAre("RowMajor", "ColumnMajor", "MFMA"));
}
TEST(InstanceTraitsUtil, LayoutNameReturnsCorrectStringsForConvLayouts)
{
namespace conv = ck::tensor_layout::convolution;
EXPECT_THAT((std::vector<std::string_view>{
// Input tensor layouts
// TODO(deprecated): Remove non-grouped layouts once instances are removed.
layout_name<conv::NCHW>(),
layout_name<conv::NHWC>(),
layout_name<conv::NCDHW>(),
layout_name<conv::NDHWC>(),
// Grouped input layouts
layout_name<conv::GNCHW>(),
layout_name<conv::GNHWC>(),
// Weight tensor layouts
layout_name<conv::KCYX>(),
layout_name<conv::KYXC>(),
layout_name<conv::GKCYX>(),
layout_name<conv::GKYXC>(),
// Output tensor layouts
layout_name<conv::NKHW>(),
layout_name<conv::NHWK>(),
layout_name<conv::GNKHW>(),
layout_name<conv::GNHWK>(),
// Strided layouts
// TODO(deprecated): Remove strided layouts once instances are removed.
layout_name<conv::G_NHW_C>(),
layout_name<conv::G_K_YX_C>(),
layout_name<conv::G_NHW_K>(),
// Bias layouts
layout_name<conv::G_C>(),
layout_name<conv::G_K>()}),
ElementsAre("NCHW",
"NHWC",
"NCDHW",
"NDHWC",
"GNCHW",
"GNHWC",
"KCYX",
"KYXC",
"GKCYX",
"GKYXC",
"NKHW",
"NHWK",
"GNKHW",
"GNHWK",
"G_NHW_C",
"G_K_YX_C",
"G_NHW_K",
"G_C",
"G_K"));
}
TEST(InstanceTraitsUtil, ElementwiseOpNameReturnsCorrectStrings)
{
namespace element_wise = ck::tensor_operation::element_wise;
EXPECT_THAT((std::vector<std::string_view>{
elementwise_op_name<element_wise::PassThrough>(),
elementwise_op_name<element_wise::Scale>(),
elementwise_op_name<element_wise::Bilinear>(),
elementwise_op_name<element_wise::Add>(),
elementwise_op_name<element_wise::AddRelu>(),
elementwise_op_name<element_wise::Relu>(),
elementwise_op_name<element_wise::BiasNormalizeInInferClamp>(),
elementwise_op_name<element_wise::Clamp>(),
elementwise_op_name<element_wise::AddClamp>()}),
ElementsAre("PassThrough",
"Scale",
"Bilinear",
"Add",
"AddRelu",
"Relu",
"BiasNormalizeInInferClamp",
"Clamp",
"AddClamp"));
}
TEST(InstanceTraitsUtil, ConvFwdSpecNameReturnsCorrectStrings)
{
using enum ck::tensor_operation::device::ConvolutionForwardSpecialization;
EXPECT_THAT(
(std::vector<std::string_view>{conv_fwd_spec_name(Default),
conv_fwd_spec_name(Filter1x1Stride1Pad0),
conv_fwd_spec_name(Filter1x1Pad0),
conv_fwd_spec_name(Filter3x3),
conv_fwd_spec_name(OddC)}),
ElementsAre("Default", "Filter1x1Stride1Pad0", "Filter1x1Pad0", "Filter3x3", "OddC"));
}
TEST(InstanceTraitsUtil, GemmSpecNameReturnsCorrectStrings)
{
using enum ck::tensor_operation::device::GemmSpecialization;
EXPECT_THAT((std::vector<std::string_view>{gemm_spec_name(Default),
gemm_spec_name(MPadding),
gemm_spec_name(NPadding),
gemm_spec_name(KPadding),
gemm_spec_name(MNPadding),
gemm_spec_name(MKPadding),
gemm_spec_name(NKPadding),
gemm_spec_name(MNKPadding),
gemm_spec_name(OPadding),
gemm_spec_name(MOPadding),
gemm_spec_name(NOPadding),
gemm_spec_name(KOPadding),
gemm_spec_name(MNOPadding),
gemm_spec_name(MKOPadding),
gemm_spec_name(NKOPadding),
gemm_spec_name(MNKOPadding)}),
ElementsAre("Default",
"MPadding",
"NPadding",
"KPadding",
"MNPadding",
"MKPadding",
"NKPadding",
"MNKPadding",
"OPadding",
"MOPadding",
"NOPadding",
"KOPadding",
"MNOPadding",
"MKOPadding",
"NKOPadding",
"MNKOPadding"));
}
TEST(InstanceTraitsUtil, PipelineSchedulerNameReturnsCorrectStrings)
{
using enum ck::BlockGemmPipelineScheduler;
EXPECT_THAT((std::vector<std::string_view>{pipeline_scheduler_name(Intrawave),
pipeline_scheduler_name(Interwave)}),
ElementsAre("Intrawave", "Interwave"));
}
TEST(InstanceTraitsUtil, PipelineVersionNameReturnsCorrectStrings)
{
using enum ck::BlockGemmPipelineVersion;
EXPECT_THAT((std::vector<std::string_view>{pipeline_version_name(v1),
pipeline_version_name(v2),
pipeline_version_name(v3),
pipeline_version_name(v4),
pipeline_version_name(v5)}),
ElementsAre("v1", "v2", "v3", "v4", "v5"));
}
TEST(InstanceTraitsUtil, TupleNameReturnsEmptyTupleForEmptyTuple)
{
EXPECT_EQ(tuple_name<ck::Tuple<>>(), "EmptyTuple");
}
TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForSingleLayout)
{
EXPECT_EQ(tuple_name<ck::Tuple<ck::tensor_layout::convolution::NCHW>>(), "Tuple(NCHW)");
}
TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForTwoLayouts)
{
EXPECT_EQ((tuple_name<ck::Tuple<ck::tensor_layout::convolution::NCHW,
ck::tensor_layout::convolution::NHWC>>()),
"Tuple(NCHW,NHWC)");
}
TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForThreeLayouts)
{
EXPECT_EQ((tuple_name<ck::Tuple<ck::tensor_layout::convolution::NCHW,
ck::tensor_layout::convolution::NHWC,
ck::tensor_layout::convolution::NKHW>>()),
"Tuple(NCHW,NHWC,NKHW)");
}
TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForSingleDataType)
{
EXPECT_EQ(tuple_name<ck::Tuple<ck::half_t>>(), "Tuple(fp16)");
}
TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForTwoDataTypes)
{
EXPECT_EQ((tuple_name<ck::Tuple<ck::half_t, float>>()), "Tuple(fp16,fp32)");
}
TEST(InstanceTraitsUtil, TupleNameReturnsTupleStringForThreeDataTypes)
{
EXPECT_EQ((tuple_name<ck::Tuple<ck::half_t, float, double>>()), "Tuple(fp16,fp32,fp64)");
}
TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForEmptySequence)
{
EXPECT_EQ(sequence_name<ck::Sequence<>>(), "Seq()");
}
TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForSingleValueSequence)
{
EXPECT_EQ(sequence_name<ck::Sequence<42>>(), "Seq(42)");
}
TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForTwoValueSequence)
{
EXPECT_EQ((sequence_name<ck::Sequence<1, 2>>()), "Seq(1,2)");
}
TEST(InstanceTraitsUtil, SequenceNameReturnsSeqStringForMultipleValueSequence)
{
EXPECT_EQ((sequence_name<ck::Sequence<256, 128, 64, 32, 16>>()), "Seq(256,128,64,32,16)");
}
} // namespace
} // namespace ck_tile::reflect::detail

View File

@@ -12,6 +12,8 @@ namespace element_wise {
struct Add
{
static constexpr const char* name = "Add";
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
@@ -94,6 +96,8 @@ struct Add
struct Max
{
static constexpr const char* name = "Max";
template <typename Y, typename X0, typename X1>
__host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
{
@@ -105,6 +109,8 @@ struct Max
struct Min
{
static constexpr const char* name = "Min";
template <typename Y, typename X0, typename X1>
__host__ __device__ void operator()(Y& y, const X0& x0, const X1& x1) const
{
@@ -116,6 +122,8 @@ struct Min
struct Multiply
{
static constexpr const char* name = "Multiply";
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
@@ -208,6 +216,8 @@ struct Multiply
struct ScaleAdd
{
static constexpr const char* name = "ScaleAdd";
__host__ __device__ ScaleAdd(float scale = 1.f) : scale_(scale) {}
template <typename Y, typename X0, typename X1>
@@ -235,6 +245,8 @@ struct ScaleAdd
struct Subtract
{
static constexpr const char* name = "Subtract";
template <typename T>
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
@@ -279,6 +291,8 @@ struct Subtract
struct Bilinear
{
static constexpr const char* name = "Bilinear";
Bilinear(float alpha = 1.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
template <typename Y, typename X0, typename X1>
@@ -353,6 +367,8 @@ struct Bilinear
struct AddClamp
{
static constexpr const char* name = "AddClamp";
AddClamp(float floor = 0.f, float ceil = NumericLimits<float>::Max())
: floor_(floor), ceil_(ceil){};
@@ -442,6 +458,8 @@ struct AddClamp
struct AddRelu
{
static constexpr const char* name = "AddRelu";
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
@@ -523,6 +541,8 @@ struct AddRelu
struct AddHardswish
{
static constexpr const char* name = "AddHardswish";
template <typename T>
__host__ __device__ constexpr void operator()(T& y, const T& x0, const T& x1) const;
@@ -560,6 +580,8 @@ struct AddHardswish
// E = FastGelu(C + D)
struct AddFastGelu
{
static constexpr const char* name = "AddFastGelu";
template <typename E, typename C, typename D>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
@@ -625,6 +647,8 @@ struct AddFastGelu
// E = MultiplyFastGelu(C + D)
struct MultiplyFastGelu
{
static constexpr const char* name = "MultiplyFastGelu";
template <typename E, typename C, typename D>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
@@ -690,6 +714,8 @@ struct MultiplyFastGelu
// E = Silu(C + D)
struct AddSilu
{
static constexpr const char* name = "AddSilu";
template <typename E, typename C, typename D>
__host__ __device__ constexpr void operator()(E& e, const C& c, const D& d) const;
@@ -740,6 +766,8 @@ struct AddSilu
struct ConvScaleAdd
{
static constexpr const char* name = "ConvScaleAdd";
__host__ __device__ ConvScaleAdd(float scale_in = 1.f,
float scale_wei = 1.f,
float scale_out = 1.f)

View File

@@ -13,6 +13,8 @@ namespace element_wise {
template <typename... UnaryOpsSet>
struct UnaryCombinedOp
{
static constexpr const char* name = "UnaryCombinedOp";
__host__ __device__ UnaryCombinedOp() : unary_ops_() {}
__host__ __device__ UnaryCombinedOp(UnaryOpsSet... unary_ops) : unary_ops_(unary_ops...) {}
@@ -33,6 +35,8 @@ struct UnaryCombinedOp
template <typename BinaryOp, typename UnaryOp0, typename UnaryOp1>
struct BinaryWithUnaryCombinedOp
{
static constexpr const char* name = "BinaryWithUnaryCombinedOp";
__host__ __device__ BinaryWithUnaryCombinedOp() : binary_op_(), unary_op0_(), unary_op1_() {}
__host__ __device__ BinaryWithUnaryCombinedOp(BinaryOp binary_op,
@@ -66,6 +70,8 @@ template <typename BinaryOp0,
typename UnaryOp2>
struct TrinaryWithUnaryCombinedOp
{
static constexpr const char* name = "TrinaryWithUnaryCombinedOp";
__host__ __device__ TrinaryWithUnaryCombinedOp()
: binary_op0_(), binary_op1_(), unary_op0_(), unary_op1_(), unary_op2_()
{

View File

@@ -33,6 +33,8 @@ namespace element_wise {
struct AddReluAdd
{
static constexpr const char* name = "AddReluAdd";
template <typename Y, typename X0, typename X1, typename X2>
__host__ __device__ constexpr void operator()(Y&, const X0&, const X1&, const X2&) const;
@@ -102,6 +104,8 @@ struct AddReluAdd
struct AddHardswishAdd
{
static constexpr const char* name = "AddHardswishAdd";
template <typename Y, typename X0, typename X1, typename X2>
__host__ __device__ constexpr void operator()(Y&, const X0&, const X1&, const X2&) const;
@@ -134,6 +138,8 @@ struct AddHardswishAdd
// E = C + D0 + D1
struct AddAdd
{
static constexpr const char* name = "AddAdd";
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const
{
@@ -163,6 +169,8 @@ struct AddAdd
// E = (C + D0) x D1
struct AddMultiply
{
static constexpr const char* name = "AddMultiply";
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
@@ -199,6 +207,8 @@ struct AddMultiply
// E = C x D0 + D1
struct MultiplyAdd
{
static constexpr const char* name = "MultiplyAdd";
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ void operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
@@ -251,6 +261,8 @@ struct MultiplyAdd
struct MultiplyMultiply
{
static constexpr const char* name = "MultiplyMultiply";
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
@@ -306,6 +318,8 @@ struct MultiplyMultiply
struct MultiplyAddFastGelu
{
static constexpr const char* name = "MultiplyAddFastGelu";
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
@@ -327,6 +341,8 @@ struct MultiplyAddFastGelu
// E = FastGelu(C + D0 + D1)
struct AddAddFastGelu
{
static constexpr const char* name = "AddAddFastGelu";
template <typename E, typename C, typename D0, typename D1>
__host__ __device__ constexpr void
operator()(E& e, const C& c, const D0& d0, const D1& d1) const;
@@ -398,6 +414,7 @@ struct AddAddFastGelu
// E = Relu(alpha1 * C + alpha2 * D0 + D1)
struct ScaleAddScaleAddRelu
{
static constexpr const char* name = "ScaleAddScaleAddRelu";
ScaleAddScaleAddRelu(const float alpha1 = 1.f, const float alpha2 = 1.f)
: alpha1_(alpha1), alpha2_(alpha2)
@@ -462,6 +479,8 @@ struct ScaleAddScaleAddRelu
struct Normalize
{
static constexpr const char* name = "Normalize";
// FIXME: is double absolutely necessary?
Normalize(double epsilon = 1e-4) : epsilon_(epsilon) {}
@@ -533,6 +552,8 @@ struct Normalize
// The data type of mean and variance is used as AccDataType
struct NormalizeInInfer
{
static constexpr const char* name = "NormalizeInInfer";
NormalizeInInfer(double epsilon = 1e-4) : epsilon_(epsilon) {}
template <typename T1, typename T2, typename T3, typename T4>
@@ -565,6 +586,8 @@ struct NormalizeInInfer
// used by Conv+Bias+BatchNorm+Clamp inference
struct BiasNormalizeInInferClamp
{
static constexpr const char* name = "BiasNormalizeInInferClamp";
BiasNormalizeInInferClamp(float floor = 0.f,
float ceil = NumericLimits<float>::Max(),
float epsilon = 1e-4)
@@ -620,6 +643,8 @@ struct UnaryTypeConvert;
template <>
struct UnaryTypeConvert<float, ck::bhalf_t>
{
static constexpr const char* name = "UnaryTypeConvert";
__host__ __device__ void operator()(float& y, ck::bhalf_t& x) const
{
y = ck::type_convert<float, ck::bhalf_t>(x);
@@ -629,6 +654,8 @@ struct UnaryTypeConvert<float, ck::bhalf_t>
template <>
struct UnaryTypeConvert<ck::bhalf_t, float>
{
static constexpr const char* name = "UnaryTypeConvert";
__host__ __device__ void operator()(ck::bhalf_t& y, float& x) const
{
y = ck::type_convert<ck::bhalf_t, float>(x);

View File

@@ -24,6 +24,8 @@ namespace element_wise {
template <typename Activation>
struct Activation_Mul_Clamp
{
static constexpr const char* name = "Activation_Mul_Clamp";
// Convolution + Activation (piecewise linear function)
// If an activation is piecewise linear function, then Activation(Sy * Qy) = Sy * Activation(Qy)
// Z = Activation(Y) = Activation(W @ X)
@@ -71,6 +73,8 @@ struct Activation_Mul_Clamp
template <typename Activation>
struct Mul_Activation_Mul_Clamp
{
static constexpr const char* name = "Mul_Activation_Mul_Clamp";
// Convolution + Activation (non piecewise linear function)
// Z = Activation(Y) = Activation(W @ X)
// Sz * Qz = Activation(Sy * Qy)
@@ -101,6 +105,8 @@ struct Mul_Activation_Mul_Clamp
template <typename Activation>
struct Activation_Mul2_Clamp
{
static constexpr const char* name = "Activation_Mul2_Clamp";
Activation_Mul2_Clamp(Activation activationOp) : activationOp_(activationOp) {}
__host__ __device__ constexpr void
@@ -131,6 +137,8 @@ struct Activation_Mul2_Clamp
template <typename Activation>
struct Add_Activation_Mul_Clamp
{
static constexpr const char* name = "Add_Activation_Mul_Clamp";
// Convolution + bias
// Let Bias = B = Sw * Sx * Qb
// Where Qb is int32
@@ -175,6 +183,8 @@ struct Add_Activation_Mul_Clamp
template <typename Activation>
struct Add_Activation_Mul2_Clamp
{
static constexpr const char* name = "Add_Activation_Mul2_Clamp";
Add_Activation_Mul2_Clamp(Activation activationOp) : activationOp_(activationOp) {}
__host__ __device__ constexpr void
@@ -206,6 +216,8 @@ struct Add_Activation_Mul2_Clamp
template <typename Activation>
struct Add_Mul_Activation_Mul_Clamp
{
static constexpr const char* name = "Add_Mul_Activation_Mul_Clamp";
// Convolution + Activation (non piecewise linear function)
// Z = Activation(Y) = Activation(W @ X + B)
// Sz * Qz = Activation(Sy * Qy)
@@ -250,6 +262,8 @@ struct Add_Mul_Activation_Mul_Clamp
template <typename Activation>
struct Add_Mul2_Activation_Mul_Clamp
{
static constexpr const char* name = "Add_Mul2_Activation_Mul_Clamp";
Add_Mul2_Activation_Mul_Clamp(float scale_z_inv, Activation activationOp)
: scale_z_inv_(scale_z_inv), activationOp_(activationOp)
{

View File

@@ -157,6 +157,8 @@ namespace element_wise {
struct PassThroughPack8
{
static constexpr const char* name = "PassThroughPack8";
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
@@ -265,6 +267,8 @@ struct PassThroughPack8
struct DequantPack8
{
static constexpr const char* name = "DequantPack8";
template <typename Y, typename X, typename Z>
__host__ __device__ void operator()(Y& y, const X& x, const Z& z) const;
@@ -301,6 +305,8 @@ struct DequantPack8
struct PassThroughPack2
{
static constexpr const char* name = "PassThroughPack2";
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
@@ -332,6 +338,8 @@ struct PassThroughPack2
struct PassThrough
{
static constexpr const char* name = "PassThrough";
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
@@ -552,12 +560,12 @@ struct PassThrough
{
y = type_convert<bf8_t>(x);
}
static constexpr const char* name = "PassThrough";
};
struct UnaryConvert
{
static constexpr const char* name = "UnaryConvert";
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
{
@@ -567,6 +575,8 @@ struct UnaryConvert
struct ConvertBF16RTN
{
static constexpr const char* name = "ConvertBF16RTN";
// convert to bf16 using round to nearest (rtn)
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
@@ -584,6 +594,8 @@ struct ConvertBF16RTN
struct ConvertF8SR
{
static constexpr const char* name = "ConvertF8SR";
// convert to fp8 using stochastic rounding (SR)
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
@@ -602,6 +614,8 @@ struct ConvertF8SR
struct ConvertF8RNE
{
static constexpr const char* name = "ConvertF8RNE";
// convert to fp8 using rounding to nearest even
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const
@@ -620,6 +634,8 @@ struct ConvertF8RNE
struct Scale
{
static constexpr const char* name = "Scale";
__host__ __device__ Scale(float scale = 1.f) : scale_(scale) {}
template <typename Y, typename X>
@@ -665,6 +681,8 @@ struct Scale
struct ScaleAndResetNaNToMinusInfinity
{
static constexpr const char* name = "ScaleAndResetNaNToMinusInfinity";
__host__ __device__ ScaleAndResetNaNToMinusInfinity(float scale) : scale_(scale) {}
template <typename Y, typename X>
@@ -681,6 +699,8 @@ struct ScaleAndResetNaNToMinusInfinity
struct UnaryDivide
{
static constexpr const char* name = "UnaryDivide";
__host__ __device__ UnaryDivide(const int32_t divider = 1) : divider_(divider) {}
template <typename T>
@@ -725,6 +745,8 @@ struct UnaryDivide
struct UnarySquare
{
static constexpr const char* name = "UnarySquare";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -741,6 +763,8 @@ struct UnarySquare
struct UnaryAbs
{
static constexpr const char* name = "UnaryAbs";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -771,6 +795,8 @@ struct UnaryAbs
struct UnarySqrt
{
static constexpr const char* name = "UnarySqrt";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -783,6 +809,8 @@ struct UnarySqrt
struct Clamp
{
static constexpr const char* name = "Clamp";
Clamp(float floor = 0.f, float ceil = NumericLimits<float>::Max())
: floor_(floor), ceil_(ceil){};
@@ -856,6 +884,8 @@ struct Clamp
struct Relu
{
static constexpr const char* name = "Relu";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -892,6 +922,8 @@ struct Relu
// gpu code use lower accuracy "_ocml_exp_f32" and "rcp" function
struct FastGelu
{
static constexpr const char* name = "FastGelu";
template <typename Y, typename X>
__host__ void operator()(Y& y, const X& x) const;
@@ -1007,6 +1039,8 @@ struct FastGelu
// y = 0.5*x*(1+erf(x/sqrt(2)))
struct Gelu
{
static constexpr const char* name = "Gelu";
template <typename Y, typename X>
__host__ __device__ void operator()(Y& y, const X& x) const;
@@ -1025,6 +1059,8 @@ struct Gelu
struct Sigmoid
{
static constexpr const char* name = "Sigmoid";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1049,6 +1085,8 @@ struct Sigmoid
struct Silu
{
static constexpr const char* name = "SiLU";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1062,6 +1100,8 @@ struct Silu
struct TanH
{
static constexpr const char* name = "TanH";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1085,6 +1125,8 @@ struct TanH
struct ACos
{
static constexpr const char* name = "ACos";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1099,6 +1141,8 @@ struct ACos
struct Neg
{
static constexpr const char* name = "Neg";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1113,6 +1157,8 @@ struct Neg
struct ATan
{
static constexpr const char* name = "ATan";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1127,6 +1173,8 @@ struct ATan
struct Sin
{
static constexpr const char* name = "Sin";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1141,6 +1189,8 @@ struct Sin
struct ASinH
{
static constexpr const char* name = "ASinH";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1155,6 +1205,8 @@ struct ASinH
struct Cos
{
static constexpr const char* name = "Cos";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1169,6 +1221,8 @@ struct Cos
struct ACosH
{
static constexpr const char* name = "ACosH";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1183,6 +1237,8 @@ struct ACosH
struct Tan
{
static constexpr const char* name = "Tan";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1197,6 +1253,8 @@ struct Tan
struct ATanH
{
static constexpr const char* name = "ATanH";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1211,6 +1269,8 @@ struct ATanH
struct SinH
{
static constexpr const char* name = "SinH";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1225,6 +1285,8 @@ struct SinH
struct Ceil
{
static constexpr const char* name = "Ceil";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1239,6 +1301,8 @@ struct Ceil
struct Exp
{
static constexpr const char* name = "Exp";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1253,6 +1317,8 @@ struct Exp
struct CosH
{
static constexpr const char* name = "CosH";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1267,6 +1333,8 @@ struct CosH
struct Floor
{
static constexpr const char* name = "Floor";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1281,6 +1349,8 @@ struct Floor
struct Log
{
static constexpr const char* name = "Log";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1295,6 +1365,8 @@ struct Log
struct ASin
{
static constexpr const char* name = "ASin";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1309,6 +1381,8 @@ struct ASin
struct Rcp
{
static constexpr const char* name = "Rcp";
template <typename T>
__host__ __device__ void operator()(T& y, const T& x) const
{
@@ -1323,6 +1397,8 @@ struct Rcp
struct Swish
{
static constexpr const char* name = "Swish";
Swish(float beta = 1.0f) : beta_(beta) {}
template <typename Y, typename X>
@@ -1352,6 +1428,8 @@ struct Swish
struct SoftRelu
{
static constexpr const char* name = "SoftRelu";
SoftRelu(float alpha = 1.f) : alpha_(alpha){};
template <typename T>
@@ -1380,6 +1458,8 @@ struct SoftRelu
struct Power
{
static constexpr const char* name = "Power";
Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
: alpha_(alpha), beta_(beta), gamma_(gamma){};
@@ -1414,6 +1494,8 @@ struct Power
struct ClippedRelu
{
static constexpr const char* name = "ClippedRelu";
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
template <typename T>
@@ -1443,6 +1525,8 @@ struct ClippedRelu
struct LeakyRelu
{
static constexpr const char* name = "LeakyRelu";
LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
template <typename T>
@@ -1470,6 +1554,8 @@ struct LeakyRelu
struct Elu
{
static constexpr const char* name = "Elu";
Elu(float alpha = 1.f) : alpha_(alpha){};
template <typename T>
@@ -1497,6 +1583,8 @@ struct Elu
struct Logistic
{
static constexpr const char* name = "Logistic";
Logistic(float alpha = 1.f) : alpha_(alpha){};
template <typename T>
@@ -1525,6 +1613,8 @@ struct Logistic
struct ConvInvscale
{
static constexpr const char* name = "ConvInvscale";
__host__ __device__ ConvInvscale(float scale_in = 1.f,
float scale_wei = 1.f,
float scale_out = 1.f)
@@ -1548,6 +1638,8 @@ struct ConvInvscale
struct ConvScale
{
static constexpr const char* name = "ConvScale";
__host__ __device__ ConvScale(float scale_in = 1.f,
float scale_wei = 1.f,
float scale_out = 1.f)
@@ -1571,6 +1663,8 @@ struct ConvScale
struct ConvScaleRelu
{
static constexpr const char* name = "ConvScaleRelu";
__host__ __device__ ConvScaleRelu(float scale_in = 1.f,
float scale_wei = 1.f,
float scale_out = 1.f)