mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 10:37:44 +00:00
Merge commit '6d709dac41409a339b82a83ea59e03fbb37c7005' into develop
This commit is contained in:
@@ -40,6 +40,8 @@ consteval std::string_view type_name()
|
||||
return "fp16";
|
||||
else if constexpr(std::is_same_v<T, float>)
|
||||
return "fp32";
|
||||
else if constexpr(std::is_same_v<T, ck::tf32_t>)
|
||||
return "tf32";
|
||||
else if constexpr(std::is_same_v<T, double>)
|
||||
return "fp64";
|
||||
else if constexpr(std::is_same_v<T, int8_t>)
|
||||
@@ -60,40 +62,25 @@ consteval std::string_view type_name()
|
||||
template <typename T>
|
||||
constexpr std::string_view layout_name()
|
||||
{
|
||||
if constexpr(requires {
|
||||
if constexpr(std::is_base_of_v<ck_tile::tensor_layout::BaseTensorLayout, T> && requires {
|
||||
{ T::name } -> std::convertible_to<std::string_view>;
|
||||
})
|
||||
return T::name;
|
||||
else
|
||||
static_assert(false, "layout type is missing name attribute");
|
||||
static_assert(false,
|
||||
"Layout type must derive from BaseTensorLayout and have name attribute");
|
||||
}
|
||||
|
||||
// Convert element-wise operation types to string names
|
||||
template <typename T>
|
||||
constexpr std::string_view elementwise_op_name()
|
||||
{
|
||||
namespace element_wise = ck::tensor_operation::element_wise;
|
||||
|
||||
if constexpr(std::is_same_v<T, element_wise::PassThrough>)
|
||||
return "PassThrough";
|
||||
else if constexpr(std::is_same_v<T, element_wise::Scale>)
|
||||
return "Scale";
|
||||
else if constexpr(std::is_same_v<T, element_wise::Bilinear>)
|
||||
return "Bilinear";
|
||||
else if constexpr(std::is_same_v<T, element_wise::Add>)
|
||||
return "Add";
|
||||
else if constexpr(std::is_same_v<T, element_wise::AddRelu>)
|
||||
return "AddRelu";
|
||||
else if constexpr(std::is_same_v<T, element_wise::Relu>)
|
||||
return "Relu";
|
||||
else if constexpr(std::is_same_v<T, element_wise::BiasNormalizeInInferClamp>)
|
||||
return "BiasNormalizeInInferClamp";
|
||||
else if constexpr(std::is_same_v<T, element_wise::Clamp>)
|
||||
return "Clamp";
|
||||
else if constexpr(std::is_same_v<T, element_wise::AddClamp>)
|
||||
return "AddClamp";
|
||||
if constexpr(requires {
|
||||
{ T::name } -> std::convertible_to<std::string_view>;
|
||||
})
|
||||
return T::name;
|
||||
else
|
||||
static_assert(false, "unknown_op");
|
||||
static_assert(false, "Elementwise operation is missing name attribute");
|
||||
}
|
||||
|
||||
// Convert ConvolutionForwardSpecialization enum to string
|
||||
|
||||
@@ -17,7 +17,10 @@ endfunction()
|
||||
|
||||
add_ck_builder_test(test_conv_builder
|
||||
test_conv_builder.cpp
|
||||
test_instance_traits.cpp)
|
||||
test_instance_traits.cpp
|
||||
testing_utils.cpp)
|
||||
|
||||
add_ck_builder_test(test_get_instance_string
|
||||
test_get_instance_string.cpp)
|
||||
|
||||
add_ck_builder_test(test_inline_diff test_inline_diff.cpp testing_utils.cpp)
|
||||
|
||||
52
experimental/builder/test/test_inline_diff.cpp
Normal file
52
experimental/builder/test/test_inline_diff.cpp
Normal file
@@ -0,0 +1,52 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "testing_utils.hpp"
|
||||
|
||||
namespace ck_tile::builder {
|
||||
namespace {
|
||||
|
||||
TEST(InlineDiff, simpleColorDiff)
|
||||
{
|
||||
std::string str1{"hello"};
|
||||
std::string str2{"hello"};
|
||||
std::string str3{"world"};
|
||||
|
||||
// some easy tests
|
||||
// you can veryfy the ungodly strings are meaningful by running echo -e "<string>"
|
||||
EXPECT_THAT(test::inlineDiff(str1, str2, true), "hello");
|
||||
EXPECT_THAT(test::inlineDiff(str1, str3, true),
|
||||
"[\x1B[36mwor\x1B[0m|\x1B[35mhel\x1B[0m]l[\x1B[36md\x1B[0m|\x1B[35mo\x1B[0m]");
|
||||
}
|
||||
|
||||
TEST(InlineDiff, noColorDiff)
|
||||
{
|
||||
std::string str1{"hello"};
|
||||
std::string str2{"hello"};
|
||||
std::string str3{"world"};
|
||||
|
||||
// some easy tests without color
|
||||
EXPECT_THAT(test::inlineDiff(str1, str2, false), "hello");
|
||||
EXPECT_THAT(test::inlineDiff(str1, str3, false), "[wor|hel]l[d|o]");
|
||||
}
|
||||
|
||||
TEST(InlineDiff, complexColorDiff)
|
||||
{
|
||||
|
||||
// now something more interesting
|
||||
std::string str4{"this part has changed, this part has been left out, this part, this part has "
|
||||
"an extra letter"};
|
||||
std::string str5{
|
||||
"this part has degeahc, this part has, this part added, this part has ana extra letter"};
|
||||
|
||||
EXPECT_THAT(
|
||||
test::inlineDiff(str5, str4, true),
|
||||
"this part has [\x1B[36mchanged\x1B[0m|\x1B[35mdegeahc\x1B[0m], this part has[\x1B[36m "
|
||||
"been left out\x1B[0m|\x1B[35m\x1B[0m], this part[\x1B[36m\x1B[0m|\x1B[35m added\x1B[0m], "
|
||||
"this part has an[\x1B[36m\x1B[0m|\x1B[35ma\x1B[0m] extra letter");
|
||||
};
|
||||
|
||||
} // namespace
|
||||
} // namespace ck_tile::builder
|
||||
219
experimental/builder/test/testing_utils.cpp
Normal file
219
experimental/builder/test/testing_utils.cpp
Normal file
@@ -0,0 +1,219 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include <unistd.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
#include "testing_utils.hpp"
|
||||
|
||||
namespace ck_tile::test {
|
||||
|
||||
namespace {
|
||||
|
||||
} // namespace
|
||||
|
||||
// Wagner-Fischer Algorithm for Computing Edit Distance and Inline Diff
|
||||
//
|
||||
// OUTPUT FORMAT: [expected|actual] for differences, plain text for matches
|
||||
// Example: "hello world" vs "hello earth" → "hello [world|earth]"
|
||||
//
|
||||
// This function implements the Wagner-Fischer algorithm (1974), which is the classic
|
||||
// dynamic programming solution for computing the minimum edit distance (Levenshtein distance)
|
||||
// between two strings. The algorithm has O(n*m) time and space complexity.
|
||||
//
|
||||
// ALGORITHM OVERVIEW:
|
||||
// 1. Build a 2D DP table where dp[i][j] represents the minimum edit distance
|
||||
// between the first i characters of 'expected' and first j characters of 'actual'
|
||||
// 2. Fill the table using the recurrence relation:
|
||||
// dp[i][j] = min(
|
||||
// dp[i-1][j] + 1, // deletion (remove char from expected)
|
||||
// dp[i][j-1] + 1, // insertion (add char to expected)
|
||||
// dp[i-1][j-1] + cost // substitution (cost=0 if chars match, 1 if different)
|
||||
// )
|
||||
// 3. Backtrack through the table to reconstruct the optimal edit sequence
|
||||
//
|
||||
// REFERENCES:
|
||||
// - Wagner, R. A.; Fischer, M. J. (1974). "The String-to-String Correction Problem"
|
||||
// - Also known as: Levenshtein distance, edit distance, string alignment
|
||||
// - Similar to sequence alignment algorithms used in bioinformatics (Needleman-Wunsch)
|
||||
std::string inlineDiff(const std::string& actual, const std::string& expected, bool use_color)
|
||||
{
|
||||
|
||||
const char* EXPECTED_COLOR = use_color ? "\033[36m" : ""; // Cyan
|
||||
const char* ACTUAL_COLOR = use_color ? "\033[35m" : ""; // Magenta
|
||||
const char* RESET = use_color ? "\033[0m" : "";
|
||||
|
||||
const size_t n = expected.length(); // Length of expected string
|
||||
const size_t m = actual.length(); // Length of actual string
|
||||
|
||||
// PHASE 1: Build the Dynamic Programming Table
|
||||
// dp[i][j] = minimum edit distance between expected[0..i-1] and actual[0..j-1]
|
||||
std::vector<std::vector<int>> dp(n + 1, std::vector<int>(m + 1));
|
||||
|
||||
// Base cases: transforming empty string to/from prefixes
|
||||
for(size_t i = 0; i <= n; ++i)
|
||||
{
|
||||
dp[i][0] = i; // Delete i characters from expected to get empty string
|
||||
}
|
||||
for(size_t j = 0; j <= m; ++j)
|
||||
{
|
||||
dp[0][j] = j; // Insert j characters to empty string to get actual[0..j-1]
|
||||
}
|
||||
|
||||
// Fill the DP table using the Wagner-Fischer recurrence relation
|
||||
for(size_t i = 1; i <= n; ++i)
|
||||
{
|
||||
for(size_t j = 1; j <= m; ++j)
|
||||
{
|
||||
// Cost is 0 if characters match, 1 if they need substitution
|
||||
int cost = (expected[i - 1] == actual[j - 1]) ? 0 : 1;
|
||||
|
||||
// Choose the minimum cost operation:
|
||||
dp[i][j] = std::min({
|
||||
dp[i - 1][j] + 1, // Deletion: remove expected[i-1]
|
||||
dp[i][j - 1] + 1, // Insertion: add actual[j-1]
|
||||
dp[i - 1][j - 1] + cost // Substitution/Match
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// PHASE 2: Backtrack to Reconstruct the Optimal Edit Sequence
|
||||
// We trace back from dp[n][m] to dp[0][0] to find which operations were used
|
||||
std::vector<char> operations; // 'M'atch, 'S'ubstitution, 'I'nsertion, 'D'eletion
|
||||
std::vector<std::pair<char, char>> diff_chars; // Character pairs for each operation
|
||||
|
||||
size_t i = n, j = m; // Start from bottom-right corner of DP table
|
||||
while(i > 0 || j > 0)
|
||||
{
|
||||
// Determine which operation led to the current cell's value
|
||||
int cost = (i > 0 && j > 0 && expected[i - 1] == actual[j - 1]) ? 0 : 1;
|
||||
|
||||
// Check if we came from diagonal (substitution/match)
|
||||
if(i > 0 && j > 0 && dp[i][j] == dp[i - 1][j - 1] + cost)
|
||||
{
|
||||
if(cost == 0)
|
||||
{
|
||||
operations.push_back('M'); // Characters match
|
||||
diff_chars.push_back({expected[i - 1], actual[j - 1]});
|
||||
}
|
||||
else
|
||||
{
|
||||
operations.push_back('S'); // Substitution needed
|
||||
diff_chars.push_back({expected[i - 1], actual[j - 1]});
|
||||
}
|
||||
--i;
|
||||
--j; // Move diagonally up-left
|
||||
}
|
||||
// Check if we came from left (insertion)
|
||||
else if(j > 0 && dp[i][j] == dp[i][j - 1] + 1)
|
||||
{
|
||||
operations.push_back('I'); // Insertion: actual has extra character
|
||||
diff_chars.push_back({'\0', actual[j - 1]});
|
||||
--j; // Move left
|
||||
}
|
||||
// Must have come from above (deletion)
|
||||
else if(i > 0 && dp[i][j] == dp[i - 1][j] + 1)
|
||||
{
|
||||
operations.push_back('D'); // Deletion: expected has extra character
|
||||
diff_chars.push_back({expected[i - 1], '\0'});
|
||||
--i; // Move up
|
||||
}
|
||||
}
|
||||
|
||||
// PHASE 3: Reverse and Build the Human-Readable Diff String
|
||||
// Backtracking gives us operations in reverse order, so we reverse to get forward order
|
||||
std::reverse(operations.begin(), operations.end());
|
||||
std::reverse(diff_chars.begin(), diff_chars.end());
|
||||
|
||||
// Build the final diff string with color highlighting
|
||||
std::ostringstream diff;
|
||||
std::string expected_diff, actual_diff; // Accumulate consecutive differences
|
||||
bool in_diff = false; // Track whether we're inside a diff section
|
||||
|
||||
for(size_t k = 0; k < operations.size(); ++k)
|
||||
{
|
||||
char op = operations[k];
|
||||
char exp_char = diff_chars[k].first; // Expected character ('\0' for insertions)
|
||||
char act_char = diff_chars[k].second; // Actual character ('\0' for deletions)
|
||||
|
||||
if(op == 'M') // Match - characters are identical
|
||||
{
|
||||
if(in_diff)
|
||||
{
|
||||
// Close the current diff section and output it
|
||||
diff << "[" << EXPECTED_COLOR << expected_diff << RESET << "|" << ACTUAL_COLOR
|
||||
<< actual_diff << RESET << "]";
|
||||
expected_diff.clear();
|
||||
actual_diff.clear();
|
||||
in_diff = false;
|
||||
}
|
||||
diff << exp_char; // Output the matching character as-is
|
||||
}
|
||||
else // Difference (substitution, insertion, or deletion)
|
||||
{
|
||||
in_diff = true;
|
||||
// Accumulate characters for the diff section
|
||||
if(exp_char != '\0')
|
||||
expected_diff += exp_char; // Add to expected side
|
||||
if(act_char != '\0')
|
||||
actual_diff += act_char; // Add to actual side
|
||||
}
|
||||
}
|
||||
|
||||
// Close any remaining diff section at the end
|
||||
if(in_diff)
|
||||
{
|
||||
diff << "[" << EXPECTED_COLOR << expected_diff << RESET << "|" << ACTUAL_COLOR
|
||||
<< actual_diff << RESET << "]";
|
||||
}
|
||||
|
||||
return diff.str();
|
||||
}
|
||||
|
||||
std::string formatInlineDiff(const std::string& actual, const std::string& expected)
|
||||
{
|
||||
return std::string("Inline diff: \"") + inlineDiff(actual, expected) + "\"";
|
||||
}
|
||||
|
||||
// StringEqWithDiffMatcher implementation
|
||||
StringEqWithDiffMatcher::StringEqWithDiffMatcher(const std::string& expected) : expected_(expected)
|
||||
{
|
||||
}
|
||||
|
||||
bool StringEqWithDiffMatcher::MatchAndExplain(std::string actual,
|
||||
::testing::MatchResultListener* listener) const
|
||||
{
|
||||
if(actual == expected_)
|
||||
{
|
||||
return true;
|
||||
}
|
||||
|
||||
// On failure, provide detailed diff information
|
||||
if(listener->IsInterested())
|
||||
{
|
||||
*listener << "\n Diff: \"" << inlineDiff(actual, expected_) << "\"";
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void StringEqWithDiffMatcher::DescribeTo(std::ostream* os) const
|
||||
{
|
||||
*os << "\"" << expected_ << "\"";
|
||||
}
|
||||
|
||||
void StringEqWithDiffMatcher::DescribeNegationTo(std::ostream* os) const
|
||||
{
|
||||
*os << "is not equal to \"" << expected_ << "\"";
|
||||
}
|
||||
|
||||
// Factory function for the StringEqWithDiff matcher
|
||||
::testing::Matcher<std::string> StringEqWithDiff(const std::string& expected)
|
||||
{
|
||||
return ::testing::MakeMatcher(new StringEqWithDiffMatcher(expected));
|
||||
}
|
||||
|
||||
} // namespace ck_tile::test
|
||||
43
experimental/builder/test/testing_utils.hpp
Normal file
43
experimental/builder/test/testing_utils.hpp
Normal file
@@ -0,0 +1,43 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <gmock/gmock.h>
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
|
||||
namespace ck_tile::test {
|
||||
|
||||
static bool isTerminalOutput() { return isatty(fileno(stdout)) || isatty(fileno(stderr)); }
|
||||
|
||||
// Returns a string highlighting differences between actual and expected.
|
||||
// Differences are enclosed in brackets with actual and expected parts separated by '|'.
|
||||
std::string inlineDiff(const std::string& actual,
|
||||
const std::string& expected,
|
||||
bool use_color = isTerminalOutput());
|
||||
|
||||
// A convenience alias for inlineDiff to improve readability in test assertions.
|
||||
// Note that the function has O(n^2) complexity both in compute and in memory - do not use for very
|
||||
// long strings
|
||||
std::string formatInlineDiff(const std::string& actual, const std::string& expected);
|
||||
|
||||
// Gmock matcher for string equality with inline diff output on failure
|
||||
class StringEqWithDiffMatcher : public ::testing::MatcherInterface<std::string>
|
||||
{
|
||||
public:
|
||||
explicit StringEqWithDiffMatcher(const std::string& expected);
|
||||
|
||||
bool MatchAndExplain(std::string actual,
|
||||
::testing::MatchResultListener* listener) const override;
|
||||
|
||||
void DescribeTo(std::ostream* os) const override;
|
||||
void DescribeNegationTo(std::ostream* os) const override;
|
||||
|
||||
private:
|
||||
std::string expected_;
|
||||
};
|
||||
|
||||
// Factory function for the StringEqWithDiff matcher
|
||||
::testing::Matcher<std::string> StringEqWithDiff(const std::string& expected);
|
||||
|
||||
} // namespace ck_tile::test
|
||||
@@ -349,6 +349,8 @@ CK_TILE_DEVICE bf8x4_t i4_to_bf8x4(int q)
|
||||
|
||||
struct PassThroughPack8
|
||||
{
|
||||
static constexpr const char* name = "PassThroughPack8";
|
||||
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
|
||||
|
||||
@@ -388,6 +390,8 @@ struct PassThroughPack8
|
||||
|
||||
struct DequantPack8
|
||||
{
|
||||
static constexpr const char* name = "DequantPack8";
|
||||
|
||||
template <typename Y, typename X, typename Z>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x, const Z& z) const;
|
||||
|
||||
@@ -403,6 +407,8 @@ struct DequantPack8
|
||||
|
||||
struct PassThroughPack2
|
||||
{
|
||||
static constexpr const char* name = "PassThroughPack2";
|
||||
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
|
||||
|
||||
@@ -429,6 +435,8 @@ struct PassThroughPack2
|
||||
|
||||
struct PassThrough
|
||||
{
|
||||
static constexpr const char* name = "PassThrough";
|
||||
|
||||
template <class T>
|
||||
using raw_t = std::remove_cv_t<std::remove_reference_t<T>>;
|
||||
|
||||
@@ -465,6 +473,8 @@ struct PassThrough
|
||||
|
||||
struct AddScale
|
||||
{
|
||||
static constexpr const char* name = "AddScale";
|
||||
|
||||
template <typename E, typename... As>
|
||||
CK_TILE_HOST_DEVICE constexpr void operator()(E& a, const As&... as) const
|
||||
{
|
||||
@@ -482,6 +492,8 @@ struct AddScale
|
||||
|
||||
struct MultiDMultiply
|
||||
{
|
||||
static constexpr const char* name = "MultiDMultiply";
|
||||
|
||||
template <typename E, typename C, typename... Ds>
|
||||
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
|
||||
{
|
||||
@@ -497,6 +509,8 @@ struct MultiDMultiply
|
||||
|
||||
struct MultiDAdd
|
||||
{
|
||||
static constexpr const char* name = "MultiDAdd";
|
||||
|
||||
template <typename E, typename C, typename... Ds>
|
||||
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
|
||||
{
|
||||
@@ -512,6 +526,8 @@ struct MultiDAdd
|
||||
|
||||
struct UnaryConvert
|
||||
{
|
||||
static constexpr const char* name = "UnaryConvert";
|
||||
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
|
||||
{
|
||||
@@ -576,6 +592,8 @@ struct ConvertF8RNE
|
||||
|
||||
struct Scale
|
||||
{
|
||||
static constexpr const char* name = "Scale";
|
||||
|
||||
CK_TILE_HOST_DEVICE Scale(float scale = 1.f) : scale_(scale) {}
|
||||
|
||||
template <typename Y, typename X>
|
||||
@@ -623,6 +641,8 @@ struct Scale
|
||||
|
||||
struct ScaleAndResetNaNToMinusInfinity
|
||||
{
|
||||
static constexpr const char* name = "ScaleAndResetNaNToMinusInfinity";
|
||||
|
||||
CK_TILE_HOST_DEVICE ScaleAndResetNaNToMinusInfinity(float scale) : scale_(scale) {}
|
||||
|
||||
template <typename Y, typename X>
|
||||
@@ -639,6 +659,8 @@ struct ScaleAndResetNaNToMinusInfinity
|
||||
|
||||
struct UnaryDivide
|
||||
{
|
||||
static constexpr const char* name = "UnaryDivide";
|
||||
|
||||
CK_TILE_HOST_DEVICE UnaryDivide(const int32_t divider = 1) : divider_(divider) {}
|
||||
|
||||
template <typename T>
|
||||
@@ -656,6 +678,8 @@ struct UnaryDivide
|
||||
|
||||
struct UnarySquare
|
||||
{
|
||||
static constexpr const char* name = "UnarySquare";
|
||||
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
|
||||
{
|
||||
@@ -673,6 +697,8 @@ struct UnarySquare
|
||||
|
||||
struct UnaryAbs
|
||||
{
|
||||
static constexpr const char* name = "UnaryAbs";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -687,6 +713,8 @@ struct UnaryAbs
|
||||
|
||||
struct UnarySqrt
|
||||
{
|
||||
static constexpr const char* name = "UnarySqrt";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -699,6 +727,8 @@ struct UnarySqrt
|
||||
|
||||
struct Relu
|
||||
{
|
||||
static constexpr const char* name = "Relu";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -725,6 +755,8 @@ struct Relu
|
||||
// gpu code use lower accuracy "_ocml_exp_f32" and "rcp" function
|
||||
struct FastGelu
|
||||
{
|
||||
static constexpr const char* name = "FastGelu";
|
||||
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST void operator()(Y& y, const X& x) const;
|
||||
|
||||
@@ -842,6 +874,8 @@ struct FastGelu
|
||||
|
||||
struct FastGeluAsm
|
||||
{
|
||||
static constexpr const char* name = "FastGeluAsm";
|
||||
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST void operator()(Y& y, const X& x) const;
|
||||
|
||||
@@ -943,6 +977,8 @@ struct FastGeluAsm
|
||||
// y = 0.5*x*(1+erf(x/sqrt(2)))
|
||||
struct Gelu
|
||||
{
|
||||
static constexpr const char* name = "Gelu";
|
||||
|
||||
template <typename Y, typename X>
|
||||
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
|
||||
|
||||
@@ -963,6 +999,8 @@ struct Gelu
|
||||
|
||||
struct Sigmoid
|
||||
{
|
||||
static constexpr const char* name = "Sigmoid";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -977,6 +1015,8 @@ struct Sigmoid
|
||||
|
||||
struct Silu
|
||||
{
|
||||
static constexpr const char* name = "Silu";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1066,6 +1106,8 @@ struct SiluAsm
|
||||
|
||||
struct TanH
|
||||
{
|
||||
static constexpr const char* name = "TanH";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1080,6 +1122,8 @@ struct TanH
|
||||
|
||||
struct ACos
|
||||
{
|
||||
static constexpr const char* name = "ACos";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1094,6 +1138,8 @@ struct ACos
|
||||
|
||||
struct Neg
|
||||
{
|
||||
static constexpr const char* name = "Neg";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1108,6 +1154,8 @@ struct Neg
|
||||
|
||||
struct ATan
|
||||
{
|
||||
static constexpr const char* name = "ATan";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1122,6 +1170,8 @@ struct ATan
|
||||
|
||||
struct Sin
|
||||
{
|
||||
static constexpr const char* name = "Sin";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1136,6 +1186,8 @@ struct Sin
|
||||
|
||||
struct ASinH
|
||||
{
|
||||
static constexpr const char* name = "ASinH";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1150,6 +1202,8 @@ struct ASinH
|
||||
|
||||
struct Cos
|
||||
{
|
||||
static constexpr const char* name = "Cos";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1164,6 +1218,8 @@ struct Cos
|
||||
|
||||
struct ACosH
|
||||
{
|
||||
static constexpr const char* name = "ACosH";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1178,6 +1234,8 @@ struct ACosH
|
||||
|
||||
struct Tan
|
||||
{
|
||||
static constexpr const char* name = "Tan";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1192,6 +1250,8 @@ struct Tan
|
||||
|
||||
struct ATanH
|
||||
{
|
||||
static constexpr const char* name = "ATanH";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1206,6 +1266,8 @@ struct ATanH
|
||||
|
||||
struct SinH
|
||||
{
|
||||
static constexpr const char* name = "SinH";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1220,6 +1282,8 @@ struct SinH
|
||||
|
||||
struct Ceil
|
||||
{
|
||||
static constexpr const char* name = "Ceil";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1234,6 +1298,8 @@ struct Ceil
|
||||
|
||||
struct Exp
|
||||
{
|
||||
static constexpr const char* name = "Exp";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1248,6 +1314,8 @@ struct Exp
|
||||
|
||||
struct CosH
|
||||
{
|
||||
static constexpr const char* name = "CosH";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1262,6 +1330,8 @@ struct CosH
|
||||
|
||||
struct Floor
|
||||
{
|
||||
static constexpr const char* name = "Floor";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1276,6 +1346,8 @@ struct Floor
|
||||
|
||||
struct Log
|
||||
{
|
||||
static constexpr const char* name = "Log";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1290,6 +1362,8 @@ struct Log
|
||||
|
||||
struct ASin
|
||||
{
|
||||
static constexpr const char* name = "ASin";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1304,6 +1378,8 @@ struct ASin
|
||||
|
||||
struct Rcp
|
||||
{
|
||||
static constexpr const char* name = "Rcp";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
|
||||
{
|
||||
@@ -1318,6 +1394,8 @@ struct Rcp
|
||||
|
||||
struct Swish
|
||||
{
|
||||
static constexpr const char* name = "Swish";
|
||||
|
||||
Swish(float beta = 1.0f) : beta_(beta) {}
|
||||
|
||||
template <typename Y, typename X>
|
||||
@@ -1340,6 +1418,8 @@ struct Swish
|
||||
|
||||
struct SoftRelu
|
||||
{
|
||||
static constexpr const char* name = "SoftRelu";
|
||||
|
||||
SoftRelu(float alpha = 1.f) : alpha_(alpha){};
|
||||
|
||||
template <typename T>
|
||||
@@ -1358,6 +1438,8 @@ struct SoftRelu
|
||||
|
||||
struct Power
|
||||
{
|
||||
static constexpr const char* name = "Power";
|
||||
|
||||
Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
|
||||
: alpha_(alpha), beta_(beta), gamma_(gamma){};
|
||||
|
||||
@@ -1381,6 +1463,8 @@ struct Power
|
||||
|
||||
struct ClippedRelu
|
||||
{
|
||||
static constexpr const char* name = "ClippedRelu";
|
||||
|
||||
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
|
||||
|
||||
template <typename T>
|
||||
@@ -1400,6 +1484,8 @@ struct ClippedRelu
|
||||
|
||||
struct LeakyRelu
|
||||
{
|
||||
static constexpr const char* name = "LeakyRelu";
|
||||
|
||||
LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
|
||||
|
||||
template <typename T>
|
||||
@@ -1417,6 +1503,8 @@ struct LeakyRelu
|
||||
|
||||
struct Elu
|
||||
{
|
||||
static constexpr const char* name = "Elu";
|
||||
|
||||
Elu(float alpha = 1.f) : alpha_(alpha){};
|
||||
|
||||
template <typename T>
|
||||
@@ -1434,6 +1522,8 @@ struct Elu
|
||||
|
||||
struct Logistic
|
||||
{
|
||||
static constexpr const char* name = "Logistic";
|
||||
|
||||
Logistic(float alpha = 1.f) : alpha_(alpha){};
|
||||
|
||||
template <typename T>
|
||||
@@ -1452,6 +1542,8 @@ struct Logistic
|
||||
|
||||
struct ConvInvscale
|
||||
{
|
||||
static constexpr const char* name = "ConvInvscale";
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
ConvInvscale(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
|
||||
: scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
|
||||
@@ -1475,6 +1567,8 @@ struct ConvInvscale
|
||||
|
||||
struct ConvScale
|
||||
{
|
||||
static constexpr const char* name = "ConvScale";
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
ConvScale(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
|
||||
: scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
|
||||
@@ -1498,6 +1592,8 @@ struct ConvScale
|
||||
|
||||
struct ConvScaleRelu
|
||||
{
|
||||
static constexpr const char* name = "ConvScaleRelu";
|
||||
|
||||
CK_TILE_HOST_DEVICE
|
||||
ConvScaleRelu(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
|
||||
: scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
|
||||
@@ -1524,6 +1620,8 @@ struct ConvScaleRelu
|
||||
template <typename DstType, typename SrcType>
|
||||
struct Cast
|
||||
{
|
||||
static constexpr const char* name = "Cast";
|
||||
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE void operator()(DstType& y, const SrcType& x) const
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user