Merge commit '6d709dac41409a339b82a83ea59e03fbb37c7005' into develop

This commit is contained in:
assistant-librarian[bot]
2025-10-25 15:11:17 +00:00
parent 4494721174
commit c29228adcf
6 changed files with 426 additions and 24 deletions

View File

@@ -40,6 +40,8 @@ consteval std::string_view type_name()
return "fp16";
else if constexpr(std::is_same_v<T, float>)
return "fp32";
else if constexpr(std::is_same_v<T, ck::tf32_t>)
return "tf32";
else if constexpr(std::is_same_v<T, double>)
return "fp64";
else if constexpr(std::is_same_v<T, int8_t>)
@@ -60,40 +62,25 @@ consteval std::string_view type_name()
template <typename T>
constexpr std::string_view layout_name()
{
if constexpr(requires {
if constexpr(std::is_base_of_v<ck_tile::tensor_layout::BaseTensorLayout, T> && requires {
{ T::name } -> std::convertible_to<std::string_view>;
})
return T::name;
else
static_assert(false, "layout type is missing name attribute");
static_assert(false,
"Layout type must derive from BaseTensorLayout and have name attribute");
}
// Convert element-wise operation types to string names
template <typename T>
constexpr std::string_view elementwise_op_name()
{
namespace element_wise = ck::tensor_operation::element_wise;
if constexpr(std::is_same_v<T, element_wise::PassThrough>)
return "PassThrough";
else if constexpr(std::is_same_v<T, element_wise::Scale>)
return "Scale";
else if constexpr(std::is_same_v<T, element_wise::Bilinear>)
return "Bilinear";
else if constexpr(std::is_same_v<T, element_wise::Add>)
return "Add";
else if constexpr(std::is_same_v<T, element_wise::AddRelu>)
return "AddRelu";
else if constexpr(std::is_same_v<T, element_wise::Relu>)
return "Relu";
else if constexpr(std::is_same_v<T, element_wise::BiasNormalizeInInferClamp>)
return "BiasNormalizeInInferClamp";
else if constexpr(std::is_same_v<T, element_wise::Clamp>)
return "Clamp";
else if constexpr(std::is_same_v<T, element_wise::AddClamp>)
return "AddClamp";
if constexpr(requires {
{ T::name } -> std::convertible_to<std::string_view>;
})
return T::name;
else
static_assert(false, "unknown_op");
static_assert(false, "Elementwise operation is missing name attribute");
}
// Convert ConvolutionForwardSpecialization enum to string

View File

@@ -17,7 +17,10 @@ endfunction()
add_ck_builder_test(test_conv_builder
test_conv_builder.cpp
test_instance_traits.cpp)
test_instance_traits.cpp
testing_utils.cpp)
add_ck_builder_test(test_get_instance_string
test_get_instance_string.cpp)
add_ck_builder_test(test_inline_diff test_inline_diff.cpp testing_utils.cpp)

View File

@@ -0,0 +1,52 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>
#include "testing_utils.hpp"
namespace ck_tile::builder {
namespace {
TEST(InlineDiff, simpleColorDiff)
{
std::string str1{"hello"};
std::string str2{"hello"};
std::string str3{"world"};
// some easy tests
// you can veryfy the ungodly strings are meaningful by running echo -e "<string>"
EXPECT_THAT(test::inlineDiff(str1, str2, true), "hello");
EXPECT_THAT(test::inlineDiff(str1, str3, true),
"[\x1B[36mwor\x1B[0m|\x1B[35mhel\x1B[0m]l[\x1B[36md\x1B[0m|\x1B[35mo\x1B[0m]");
}
TEST(InlineDiff, noColorDiff)
{
std::string str1{"hello"};
std::string str2{"hello"};
std::string str3{"world"};
// some easy tests without color
EXPECT_THAT(test::inlineDiff(str1, str2, false), "hello");
EXPECT_THAT(test::inlineDiff(str1, str3, false), "[wor|hel]l[d|o]");
}
TEST(InlineDiff, complexColorDiff)
{
// now something more interesting
std::string str4{"this part has changed, this part has been left out, this part, this part has "
"an extra letter"};
std::string str5{
"this part has degeahc, this part has, this part added, this part has ana extra letter"};
EXPECT_THAT(
test::inlineDiff(str5, str4, true),
"this part has [\x1B[36mchanged\x1B[0m|\x1B[35mdegeahc\x1B[0m], this part has[\x1B[36m "
"been left out\x1B[0m|\x1B[35m\x1B[0m], this part[\x1B[36m\x1B[0m|\x1B[35m added\x1B[0m], "
"this part has an[\x1B[36m\x1B[0m|\x1B[35ma\x1B[0m] extra letter");
};
} // namespace
} // namespace ck_tile::builder

View File

@@ -0,0 +1,219 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <string>
#include <sstream>
#include <vector>
#include <algorithm>
#include <unistd.h>
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include "testing_utils.hpp"
namespace ck_tile::test {
namespace {
} // namespace
// Wagner-Fischer Algorithm for Computing Edit Distance and Inline Diff
//
// OUTPUT FORMAT: [expected|actual] for differences, plain text for matches
// Example: "hello world" vs "hello earth" → "hello [world|earth]"
//
// This function implements the Wagner-Fischer algorithm (1974), which is the classic
// dynamic programming solution for computing the minimum edit distance (Levenshtein distance)
// between two strings. The algorithm has O(n*m) time and space complexity.
//
// ALGORITHM OVERVIEW:
// 1. Build a 2D DP table where dp[i][j] represents the minimum edit distance
// between the first i characters of 'expected' and first j characters of 'actual'
// 2. Fill the table using the recurrence relation:
// dp[i][j] = min(
// dp[i-1][j] + 1, // deletion (remove char from expected)
// dp[i][j-1] + 1, // insertion (add char to expected)
// dp[i-1][j-1] + cost // substitution (cost=0 if chars match, 1 if different)
// )
// 3. Backtrack through the table to reconstruct the optimal edit sequence
//
// REFERENCES:
// - Wagner, R. A.; Fischer, M. J. (1974). "The String-to-String Correction Problem"
// - Also known as: Levenshtein distance, edit distance, string alignment
// - Similar to sequence alignment algorithms used in bioinformatics (Needleman-Wunsch)
std::string inlineDiff(const std::string& actual, const std::string& expected, bool use_color)
{
const char* EXPECTED_COLOR = use_color ? "\033[36m" : ""; // Cyan
const char* ACTUAL_COLOR = use_color ? "\033[35m" : ""; // Magenta
const char* RESET = use_color ? "\033[0m" : "";
const size_t n = expected.length(); // Length of expected string
const size_t m = actual.length(); // Length of actual string
// PHASE 1: Build the Dynamic Programming Table
// dp[i][j] = minimum edit distance between expected[0..i-1] and actual[0..j-1]
std::vector<std::vector<int>> dp(n + 1, std::vector<int>(m + 1));
// Base cases: transforming empty string to/from prefixes
for(size_t i = 0; i <= n; ++i)
{
dp[i][0] = i; // Delete i characters from expected to get empty string
}
for(size_t j = 0; j <= m; ++j)
{
dp[0][j] = j; // Insert j characters to empty string to get actual[0..j-1]
}
// Fill the DP table using the Wagner-Fischer recurrence relation
for(size_t i = 1; i <= n; ++i)
{
for(size_t j = 1; j <= m; ++j)
{
// Cost is 0 if characters match, 1 if they need substitution
int cost = (expected[i - 1] == actual[j - 1]) ? 0 : 1;
// Choose the minimum cost operation:
dp[i][j] = std::min({
dp[i - 1][j] + 1, // Deletion: remove expected[i-1]
dp[i][j - 1] + 1, // Insertion: add actual[j-1]
dp[i - 1][j - 1] + cost // Substitution/Match
});
}
}
// PHASE 2: Backtrack to Reconstruct the Optimal Edit Sequence
// We trace back from dp[n][m] to dp[0][0] to find which operations were used
std::vector<char> operations; // 'M'atch, 'S'ubstitution, 'I'nsertion, 'D'eletion
std::vector<std::pair<char, char>> diff_chars; // Character pairs for each operation
size_t i = n, j = m; // Start from bottom-right corner of DP table
while(i > 0 || j > 0)
{
// Determine which operation led to the current cell's value
int cost = (i > 0 && j > 0 && expected[i - 1] == actual[j - 1]) ? 0 : 1;
// Check if we came from diagonal (substitution/match)
if(i > 0 && j > 0 && dp[i][j] == dp[i - 1][j - 1] + cost)
{
if(cost == 0)
{
operations.push_back('M'); // Characters match
diff_chars.push_back({expected[i - 1], actual[j - 1]});
}
else
{
operations.push_back('S'); // Substitution needed
diff_chars.push_back({expected[i - 1], actual[j - 1]});
}
--i;
--j; // Move diagonally up-left
}
// Check if we came from left (insertion)
else if(j > 0 && dp[i][j] == dp[i][j - 1] + 1)
{
operations.push_back('I'); // Insertion: actual has extra character
diff_chars.push_back({'\0', actual[j - 1]});
--j; // Move left
}
// Must have come from above (deletion)
else if(i > 0 && dp[i][j] == dp[i - 1][j] + 1)
{
operations.push_back('D'); // Deletion: expected has extra character
diff_chars.push_back({expected[i - 1], '\0'});
--i; // Move up
}
}
// PHASE 3: Reverse and Build the Human-Readable Diff String
// Backtracking gives us operations in reverse order, so we reverse to get forward order
std::reverse(operations.begin(), operations.end());
std::reverse(diff_chars.begin(), diff_chars.end());
// Build the final diff string with color highlighting
std::ostringstream diff;
std::string expected_diff, actual_diff; // Accumulate consecutive differences
bool in_diff = false; // Track whether we're inside a diff section
for(size_t k = 0; k < operations.size(); ++k)
{
char op = operations[k];
char exp_char = diff_chars[k].first; // Expected character ('\0' for insertions)
char act_char = diff_chars[k].second; // Actual character ('\0' for deletions)
if(op == 'M') // Match - characters are identical
{
if(in_diff)
{
// Close the current diff section and output it
diff << "[" << EXPECTED_COLOR << expected_diff << RESET << "|" << ACTUAL_COLOR
<< actual_diff << RESET << "]";
expected_diff.clear();
actual_diff.clear();
in_diff = false;
}
diff << exp_char; // Output the matching character as-is
}
else // Difference (substitution, insertion, or deletion)
{
in_diff = true;
// Accumulate characters for the diff section
if(exp_char != '\0')
expected_diff += exp_char; // Add to expected side
if(act_char != '\0')
actual_diff += act_char; // Add to actual side
}
}
// Close any remaining diff section at the end
if(in_diff)
{
diff << "[" << EXPECTED_COLOR << expected_diff << RESET << "|" << ACTUAL_COLOR
<< actual_diff << RESET << "]";
}
return diff.str();
}
std::string formatInlineDiff(const std::string& actual, const std::string& expected)
{
return std::string("Inline diff: \"") + inlineDiff(actual, expected) + "\"";
}
// StringEqWithDiffMatcher implementation
StringEqWithDiffMatcher::StringEqWithDiffMatcher(const std::string& expected) : expected_(expected)
{
}
bool StringEqWithDiffMatcher::MatchAndExplain(std::string actual,
::testing::MatchResultListener* listener) const
{
if(actual == expected_)
{
return true;
}
// On failure, provide detailed diff information
if(listener->IsInterested())
{
*listener << "\n Diff: \"" << inlineDiff(actual, expected_) << "\"";
}
return false;
}
void StringEqWithDiffMatcher::DescribeTo(std::ostream* os) const
{
*os << "\"" << expected_ << "\"";
}
void StringEqWithDiffMatcher::DescribeNegationTo(std::ostream* os) const
{
*os << "is not equal to \"" << expected_ << "\"";
}
// Factory function for the StringEqWithDiff matcher
::testing::Matcher<std::string> StringEqWithDiff(const std::string& expected)
{
return ::testing::MakeMatcher(new StringEqWithDiffMatcher(expected));
}
} // namespace ck_tile::test

View File

@@ -0,0 +1,43 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include <string>
#include <sstream>
namespace ck_tile::test {
static bool isTerminalOutput() { return isatty(fileno(stdout)) || isatty(fileno(stderr)); }
// Returns a string highlighting differences between actual and expected.
// Differences are enclosed in brackets with actual and expected parts separated by '|'.
std::string inlineDiff(const std::string& actual,
const std::string& expected,
bool use_color = isTerminalOutput());
// A convenience alias for inlineDiff to improve readability in test assertions.
// Note that the function has O(n^2) complexity both in compute and in memory - do not use for very
// long strings
std::string formatInlineDiff(const std::string& actual, const std::string& expected);
// Gmock matcher for string equality with inline diff output on failure
class StringEqWithDiffMatcher : public ::testing::MatcherInterface<std::string>
{
public:
explicit StringEqWithDiffMatcher(const std::string& expected);
bool MatchAndExplain(std::string actual,
::testing::MatchResultListener* listener) const override;
void DescribeTo(std::ostream* os) const override;
void DescribeNegationTo(std::ostream* os) const override;
private:
std::string expected_;
};
// Factory function for the StringEqWithDiff matcher
::testing::Matcher<std::string> StringEqWithDiff(const std::string& expected);
} // namespace ck_tile::test

View File

@@ -349,6 +349,8 @@ CK_TILE_DEVICE bf8x4_t i4_to_bf8x4(int q)
struct PassThroughPack8
{
static constexpr const char* name = "PassThroughPack8";
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
@@ -388,6 +390,8 @@ struct PassThroughPack8
struct DequantPack8
{
static constexpr const char* name = "DequantPack8";
template <typename Y, typename X, typename Z>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x, const Z& z) const;
@@ -403,6 +407,8 @@ struct DequantPack8
struct PassThroughPack2
{
static constexpr const char* name = "PassThroughPack2";
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
@@ -429,6 +435,8 @@ struct PassThroughPack2
struct PassThrough
{
static constexpr const char* name = "PassThrough";
template <class T>
using raw_t = std::remove_cv_t<std::remove_reference_t<T>>;
@@ -465,6 +473,8 @@ struct PassThrough
struct AddScale
{
static constexpr const char* name = "AddScale";
template <typename E, typename... As>
CK_TILE_HOST_DEVICE constexpr void operator()(E& a, const As&... as) const
{
@@ -482,6 +492,8 @@ struct AddScale
struct MultiDMultiply
{
static constexpr const char* name = "MultiDMultiply";
template <typename E, typename C, typename... Ds>
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
{
@@ -497,6 +509,8 @@ struct MultiDMultiply
struct MultiDAdd
{
static constexpr const char* name = "MultiDAdd";
template <typename E, typename C, typename... Ds>
CK_TILE_HOST_DEVICE auto operator()(E& e, const C& c, const Ds&... ds) const -> void
{
@@ -512,6 +526,8 @@ struct MultiDAdd
struct UnaryConvert
{
static constexpr const char* name = "UnaryConvert";
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
@@ -576,6 +592,8 @@ struct ConvertF8RNE
struct Scale
{
static constexpr const char* name = "Scale";
CK_TILE_HOST_DEVICE Scale(float scale = 1.f) : scale_(scale) {}
template <typename Y, typename X>
@@ -623,6 +641,8 @@ struct Scale
struct ScaleAndResetNaNToMinusInfinity
{
static constexpr const char* name = "ScaleAndResetNaNToMinusInfinity";
CK_TILE_HOST_DEVICE ScaleAndResetNaNToMinusInfinity(float scale) : scale_(scale) {}
template <typename Y, typename X>
@@ -639,6 +659,8 @@ struct ScaleAndResetNaNToMinusInfinity
struct UnaryDivide
{
static constexpr const char* name = "UnaryDivide";
CK_TILE_HOST_DEVICE UnaryDivide(const int32_t divider = 1) : divider_(divider) {}
template <typename T>
@@ -656,6 +678,8 @@ struct UnaryDivide
struct UnarySquare
{
static constexpr const char* name = "UnarySquare";
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const
{
@@ -673,6 +697,8 @@ struct UnarySquare
struct UnaryAbs
{
static constexpr const char* name = "UnaryAbs";
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
@@ -687,6 +713,8 @@ struct UnaryAbs
struct UnarySqrt
{
static constexpr const char* name = "UnarySqrt";
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
@@ -699,6 +727,8 @@ struct UnarySqrt
struct Relu
{
static constexpr const char* name = "Relu";
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
@@ -725,6 +755,8 @@ struct Relu
// gpu code use lower accuracy "_ocml_exp_f32" and "rcp" function
struct FastGelu
{
static constexpr const char* name = "FastGelu";
template <typename Y, typename X>
CK_TILE_HOST void operator()(Y& y, const X& x) const;
@@ -842,6 +874,8 @@ struct FastGelu
struct FastGeluAsm
{
static constexpr const char* name = "FastGeluAsm";
template <typename Y, typename X>
CK_TILE_HOST void operator()(Y& y, const X& x) const;
@@ -943,6 +977,8 @@ struct FastGeluAsm
// y = 0.5*x*(1+erf(x/sqrt(2)))
struct Gelu
{
static constexpr const char* name = "Gelu";
template <typename Y, typename X>
CK_TILE_HOST_DEVICE void operator()(Y& y, const X& x) const;
@@ -963,6 +999,8 @@ struct Gelu
struct Sigmoid
{
static constexpr const char* name = "Sigmoid";
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
@@ -977,6 +1015,8 @@ struct Sigmoid
struct Silu
{
static constexpr const char* name = "Silu";
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
@@ -1066,6 +1106,8 @@ struct SiluAsm
struct TanH
{
static constexpr const char* name = "TanH";
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
@@ -1080,6 +1122,8 @@ struct TanH
struct ACos
{
static constexpr const char* name = "ACos";
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
@@ -1094,6 +1138,8 @@ struct ACos
struct Neg
{
static constexpr const char* name = "Neg";
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
@@ -1108,6 +1154,8 @@ struct Neg
struct ATan
{
static constexpr const char* name = "ATan";
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
@@ -1122,6 +1170,8 @@ struct ATan
struct Sin
{
static constexpr const char* name = "Sin";
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
@@ -1136,6 +1186,8 @@ struct Sin
struct ASinH
{
static constexpr const char* name = "ASinH";
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
@@ -1150,6 +1202,8 @@ struct ASinH
struct Cos
{
static constexpr const char* name = "Cos";
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
@@ -1164,6 +1218,8 @@ struct Cos
struct ACosH
{
static constexpr const char* name = "ACosH";
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
@@ -1178,6 +1234,8 @@ struct ACosH
struct Tan
{
static constexpr const char* name = "Tan";
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
@@ -1192,6 +1250,8 @@ struct Tan
struct ATanH
{
static constexpr const char* name = "ATanH";
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
@@ -1206,6 +1266,8 @@ struct ATanH
struct SinH
{
static constexpr const char* name = "SinH";
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
@@ -1220,6 +1282,8 @@ struct SinH
struct Ceil
{
static constexpr const char* name = "Ceil";
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
@@ -1234,6 +1298,8 @@ struct Ceil
struct Exp
{
static constexpr const char* name = "Exp";
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
@@ -1248,6 +1314,8 @@ struct Exp
struct CosH
{
static constexpr const char* name = "CosH";
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
@@ -1262,6 +1330,8 @@ struct CosH
struct Floor
{
static constexpr const char* name = "Floor";
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
@@ -1276,6 +1346,8 @@ struct Floor
struct Log
{
static constexpr const char* name = "Log";
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
@@ -1290,6 +1362,8 @@ struct Log
struct ASin
{
static constexpr const char* name = "ASin";
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
@@ -1304,6 +1378,8 @@ struct ASin
struct Rcp
{
static constexpr const char* name = "Rcp";
template <typename T>
CK_TILE_HOST_DEVICE void operator()(T& y, const T& x) const
{
@@ -1318,6 +1394,8 @@ struct Rcp
struct Swish
{
static constexpr const char* name = "Swish";
Swish(float beta = 1.0f) : beta_(beta) {}
template <typename Y, typename X>
@@ -1340,6 +1418,8 @@ struct Swish
struct SoftRelu
{
static constexpr const char* name = "SoftRelu";
SoftRelu(float alpha = 1.f) : alpha_(alpha){};
template <typename T>
@@ -1358,6 +1438,8 @@ struct SoftRelu
struct Power
{
static constexpr const char* name = "Power";
Power(float alpha = 0.f, float beta = 1.f, float gamma = 2.f)
: alpha_(alpha), beta_(beta), gamma_(gamma){};
@@ -1381,6 +1463,8 @@ struct Power
struct ClippedRelu
{
static constexpr const char* name = "ClippedRelu";
ClippedRelu(float alpha = 0.f, float beta = 1.f) : alpha_(alpha), beta_(beta){};
template <typename T>
@@ -1400,6 +1484,8 @@ struct ClippedRelu
struct LeakyRelu
{
static constexpr const char* name = "LeakyRelu";
LeakyRelu(float alpha = 0.01f) : alpha_(alpha){};
template <typename T>
@@ -1417,6 +1503,8 @@ struct LeakyRelu
struct Elu
{
static constexpr const char* name = "Elu";
Elu(float alpha = 1.f) : alpha_(alpha){};
template <typename T>
@@ -1434,6 +1522,8 @@ struct Elu
struct Logistic
{
static constexpr const char* name = "Logistic";
Logistic(float alpha = 1.f) : alpha_(alpha){};
template <typename T>
@@ -1452,6 +1542,8 @@ struct Logistic
struct ConvInvscale
{
static constexpr const char* name = "ConvInvscale";
CK_TILE_HOST_DEVICE
ConvInvscale(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
: scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
@@ -1475,6 +1567,8 @@ struct ConvInvscale
struct ConvScale
{
static constexpr const char* name = "ConvScale";
CK_TILE_HOST_DEVICE
ConvScale(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
: scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
@@ -1498,6 +1592,8 @@ struct ConvScale
struct ConvScaleRelu
{
static constexpr const char* name = "ConvScaleRelu";
CK_TILE_HOST_DEVICE
ConvScaleRelu(float scale_in = 1.f, float scale_wei = 1.f, float scale_out = 1.f)
: scale_in_(scale_in), scale_wei_(scale_wei), scale_out_(scale_out)
@@ -1524,6 +1620,8 @@ struct ConvScaleRelu
template <typename DstType, typename SrcType>
struct Cast
{
static constexpr const char* name = "Cast";
template <typename T>
CK_TILE_HOST_DEVICE void operator()(DstType& y, const SrcType& x) const
{