[CK_TILE] Support for elementwise kernel (#2246)

* Elementwise kernel implementation

Co-authored-by: Sami Aario <samaario@amd.com>
Co-authored-by: Mohsen Saffari <mohsen.saffari@amd.com>
Co-authored-by: yashagar <yashagar@amd.com>

* Elementwise with generalized nDims

* Adding the n-ary input tensor feature

* Generalize dimensions on top of inputs

* Add TFLOPS + remove std usage for tuples

* 1D basecase optimization

* Cleanup code + refactoring to a common interface

* Generalize to unary and add an example

* Cleanup, refactoring and commenting

* Suggestions for LWPCK-3170: elementwise kernel improvements

* Clang-format: remod.py

* Replace InputTensorType with XDataType as the type of input_tensors

* Add Tuple::apply and use it in ElementWiseKernel::operator to call operation with the exact number of arguments in xs

* Move examples to folder 19_elementwise

* Add missing copyright headers and fix some existing ones

* Replace an assert with throw std::runtime_error in elementwise example

* Avoid reading the output by using make_static_distributed_tensor for y_tile

* Removed two unused includes

* No need to move windows to the next block when each workgroup processes a single tile

* Only copy input tensors to the device

* Use get_warp_size to obtain warp size, and use ceiling division for grid size also for the unary example

* Adding output strides to the kernel, transposition example and update the other examples

* Changes made by remod.py

* Use default template parameter values for memory operation and coherence in a call to make_naive_tensor_view

* Move binary operations to include/ck_tile/ops/elementwise/binary_elementwise_operation.hpp

* Reuse generic reference binary/unary operation in examples + refactoring the transpose reference

* Fix comments in elementwise_example.cpp

- Refer to AMD terminology except when suggesting NVIDIA alternatives in parentheses
- ElementWiseTraits was renamed to ElementWiseShape
- Adopt suggestions made by Copilot when prompted to check for factual or typographical errors

* Simplify CMakeLists.txt and remove the unused variables this uncovers

* Rename a file and fix some copyright statements

* Changes made by script/clang-format-overwrite.sh

* Add basic unit test for ElementWiseKernel

* Remove left-over uninformative comment in apply unit test

* Changes made by clang-format-overwrite.sh

* fixup! Use default template parameter values for memory operation and coherence in a call to make_naive_tensor_view

* Clean up test_tuple_apply.cpp and test_elementwise_1d.cpp

* Use make_uniform_array_with_factory to define h_xs and d_xs_mems_owner as type std::array

* Use a DeviceMem constructor that calls get_element_space_size_in_bytes internally

* Move examples to folder 20_elementwise

* Reduced register pressure on the CK tile elementwise kernel + add 4d input example to be able benchmark against old CK

* Fix CLang formating

* Bump up the elementwise example folder number

* Elementwise: add padding + minor cleanup

* Add Vector Size inference + fix issue with wrong vectorization due to missing GuaranteedLastDimensionVectorStride setting in make_naive_tensor_view

* Add isSupportedArg to Elementwise kernel + addapt example and unit tests

* Fix clang-format on the unit test file

---------

Co-authored-by: Damien Lejeune <damien.lejeune@amd.com>
Co-authored-by: Sami Aario <samaario@amd.com>
Co-authored-by: Mohsen Saffari <mohsen.saffari@amd.com>
Co-authored-by: Aviral Goel <aviral.goel@amd.com>
This commit is contained in:
Yashvardhan Agarwal
2025-07-24 12:21:45 +03:00
committed by GitHub
parent 6681593864
commit 606b0cc947
23 changed files with 1509 additions and 6 deletions

View File

@@ -0,0 +1,223 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>
#include "ck_tile/core.hpp"
using namespace ck_tile;
class TestCkTileTupleApply : public ::testing::Test
{
public:
// Test functors for different scenarios
struct AddFunction
{
template <typename... Args>
CK_TILE_HOST_DEVICE constexpr auto operator()(Args... args) const
{
return (args + ...);
}
};
struct MultiplyFunction
{
template <typename... Args>
CK_TILE_HOST_DEVICE constexpr auto operator()(Args... args) const
{
return (args * ...);
}
};
struct MaxFunction
{
template <typename T>
CK_TILE_HOST_DEVICE constexpr T operator()(T a) const
{
return a;
}
template <typename T, typename... Args>
CK_TILE_HOST_DEVICE constexpr T operator()(T a, Args... args) const
{
auto rest_max = operator()(args...);
return a > rest_max ? a : rest_max;
}
};
struct ReturnTupleFunction
{
template <typename... Args>
CK_TILE_HOST_DEVICE constexpr auto operator()(Args... args) const
{
return make_tuple(args..., sizeof...(args));
}
};
};
TEST_F(TestCkTileTupleApply, BasicArithmetic)
{
// Test with simple arithmetic operations
auto t1 = make_tuple(1, 2, 3);
auto result1 = apply(AddFunction{}, t1);
EXPECT_EQ(result1, 6);
auto t2 = make_tuple(2, 3, 4, 5);
auto result2 = apply(MultiplyFunction{}, t2);
EXPECT_EQ(result2, 120);
}
TEST_F(TestCkTileTupleApply, SingleElement)
{
// Test with single element tuple
auto t1 = make_tuple(42);
auto result1 = apply(AddFunction{}, t1);
EXPECT_EQ(result1, 42);
auto result2 = apply(MultiplyFunction{}, t1);
EXPECT_EQ(result2, 42);
}
TEST_F(TestCkTileTupleApply, EmptyTuple)
{
// Test with empty tuple
auto t = tuple<>{};
auto result = apply([]() { return 100; }, t);
EXPECT_EQ(result, 100);
}
TEST_F(TestCkTileTupleApply, DifferentTypes)
{
// Test with different data types
auto t1 = make_tuple(1, 2.5f, 3.0);
auto result1 = apply(AddFunction{}, t1);
EXPECT_FLOAT_EQ(result1, 6.5f);
// Test with mixed integer and floating point
auto t2 = make_tuple(10, 0.5f);
auto result2 = apply(MultiplyFunction{}, t2);
EXPECT_FLOAT_EQ(result2, 5.0f);
}
TEST_F(TestCkTileTupleApply, ReturnTuple)
{
// Test function that returns a tuple
auto t = make_tuple(1, 2, 3);
auto result = apply(ReturnTupleFunction{}, t);
EXPECT_EQ(result.get<0>(), 1);
EXPECT_EQ(result.get<1>(), 2);
EXPECT_EQ(result.get<2>(), 3);
EXPECT_EQ(result.get<3>(), 3); // size
}
TEST_F(TestCkTileTupleApply, LambdaFunction)
{
// Test with lambda functions
auto t1 = make_tuple(5, 10, 15);
auto result1 = apply([](auto a, auto b, auto c) { return a + b + c; }, t1);
EXPECT_EQ(result1, 30);
// Test lambda with capture
int multiplier = 2;
auto result2 =
apply([multiplier](auto a, auto b) { return (a + b) * multiplier; }, make_tuple(3, 7));
EXPECT_EQ(result2, 20);
}
TEST_F(TestCkTileTupleApply, ConstexprContext)
{
// Test in constexpr context
constexpr auto t = make_tuple(2, 3, 4);
constexpr auto result = apply(MultiplyFunction{}, t);
static_assert(result == 24, "Constexpr apply should work");
EXPECT_EQ(result, 24);
}
TEST_F(TestCkTileTupleApply, ReferenceTypes)
{
// Test with reference types using tie
int a = 1, b = 2, c = 3;
auto ref_tuple = tie(a, b, c);
// Function that modifies references
apply(
[](auto& x, auto& y, auto& z) {
x += 10;
y += 20;
z += 30;
},
ref_tuple);
EXPECT_EQ(a, 11);
EXPECT_EQ(b, 22);
EXPECT_EQ(c, 33);
}
TEST_F(TestCkTileTupleApply, MoveSemantics)
{
// Test with move semantics
auto t = make_tuple(1, 2, 3);
auto result = apply(AddFunction{}, std::move(t));
EXPECT_EQ(result, 6);
}
TEST_F(TestCkTileTupleApply, NumberTypes)
{
// Test with ck_tile::number types
auto t = make_tuple(number<1>{}, number<2>{}, number<3>{});
auto result = apply([](auto a, auto b, auto c) { return a + b + c; }, t);
EXPECT_EQ(result, 6);
}
TEST_F(TestCkTileTupleApply, ElementwiseOperation)
{
// Test simulating elementwise operations
auto input1 = make_tuple(1.0f, 2.0f, 3.0f);
auto input2 = make_tuple(4.0f, 5.0f, 6.0f);
auto add_elementwise = [](const auto& a, const auto& b) {
return apply(
[&b](auto... args_a) {
return apply(
[args_a...](auto... args_b) { return make_tuple((args_a + args_b)...); }, b);
},
a);
};
auto result = add_elementwise(input1, input2);
EXPECT_FLOAT_EQ(result.get<0>(), 5.0f);
EXPECT_FLOAT_EQ(result.get<1>(), 7.0f);
EXPECT_FLOAT_EQ(result.get<2>(), 9.0f);
}
template <typename T>
class TestCkTileTupleApplySize : public TestCkTileTupleApply
{
protected:
static constexpr int Size = T::value;
};
using TupleSizes = ::testing::Types<std::integral_constant<int, 1>,
std::integral_constant<int, 2>,
std::integral_constant<int, 3>,
std::integral_constant<int, 4>,
std::integral_constant<int, 8>,
std::integral_constant<int, 16>>;
TYPED_TEST_SUITE(TestCkTileTupleApplySize, TupleSizes);
TYPED_TEST(TestCkTileTupleApplySize, GeneratedTupleSum)
{
constexpr int N = TypeParam::value;
// Generate tuple with values 1, 2, 3, ..., N
constexpr auto t = generate_tuple([](auto i) { return i.value + 1; }, number<N>{});
// Sum all elements
constexpr auto result = apply(TestCkTileTupleApply::AddFunction{}, t);
// Expected sum: 1 + 2 + ... + N = N*(N+1)/2
constexpr int expected = N * (N + 1) / 2;
static_assert(result == expected);
}