[CK_TILE] Support for elementwise kernel (#2246)

* Elementwise kernel implementation

Co-authored-by: Sami Aario <samaario@amd.com>
Co-authored-by: Mohsen Saffari <mohsen.saffari@amd.com>
Co-authored-by: yashagar <yashagar@amd.com>

* Elementwise with generalized nDims

* Adding the n-ary input tensor feature

* Generalize dimensions on top of inputs

* Add TFLOPS + remove std usage for tuples

* 1D basecase optimization

* Cleanup code + refactoring to a common interface

* Generalize to unary and add an example

* Cleanup, refactoring and commenting

* Suggestions for LWPCK-3170: elementwise kernel improvements

* Clang-format: remod.py

* Replace InputTensorType with XDataType as the type of input_tensors

* Add Tuple::apply and use it in ElementWiseKernel::operator() to call the operation with the exact number of arguments in xs

* Move examples to folder 19_elementwise

* Add missing copyright headers and fix some existing ones

* Replace an assert with throw std::runtime_error in elementwise example

* Avoid reading the output by using make_static_distributed_tensor for y_tile

* Removed two unused includes

* No need to move windows to the next block when each workgroup processes a single tile

* Only copy input tensors to the device

* Use get_warp_size to obtain warp size, and use ceiling division for grid size also for the unary example

* Add output strides to the kernel, add a transposition example, and update the other examples

* Changes made by remod.py

* Use default template parameter values for memory operation and coherence in a call to make_naive_tensor_view

* Move binary operations to include/ck_tile/ops/elementwise/binary_elementwise_operation.hpp

* Reuse generic reference binary/unary operation in examples + refactoring the transpose reference

* Fix comments in elementwise_example.cpp

- Refer to AMD terminology except when suggesting NVIDIA alternatives in parentheses
- ElementWiseTraits was renamed to ElementWiseShape
- Adopt suggestions made by Copilot when prompted to check for factual or typographical errors

* Simplify CMakeLists.txt and remove the unused variables this uncovers

* Rename a file and fix some copyright statements

* Changes made by script/clang-format-overwrite.sh

* Add basic unit test for ElementWiseKernel

* Remove left-over uninformative comment in apply unit test

* Changes made by clang-format-overwrite.sh

* fixup! Use default template parameter values for memory operation and coherence in a call to make_naive_tensor_view

* Clean up test_tuple_apply.cpp and test_elementwise_1d.cpp

* Use make_uniform_array_with_factory to define h_xs and d_xs_mems_owner as type std::array

* Use a DeviceMem constructor that calls get_element_space_size_in_bytes internally

* Move examples to folder 20_elementwise

* Reduced register pressure on the CK tile elementwise kernel + add 4D input example to be able to benchmark against old CK

* Fix clang formatting

* Bump up the elementwise example folder number

* Elementwise: add padding + minor cleanup

* Add Vector Size inference + fix issue with wrong vectorization due to missing GuaranteedLastDimensionVectorStride setting in make_naive_tensor_view

* Add isSupportedArg to Elementwise kernel + adapt example and unit tests

* Fix clang-format on the unit test file

---------

Co-authored-by: Damien Lejeune <damien.lejeune@amd.com>
Co-authored-by: Sami Aario <samaario@amd.com>
Co-authored-by: Mohsen Saffari <mohsen.saffari@amd.com>
Co-authored-by: Aviral Goel <aviral.goel@amd.com>

[ROCm/composable_kernel commit: 606b0cc947]
Author: Yashvardhan Agarwal, 2025-07-24 12:21:45 +03:00 (committed by GitHub)
Parent: bdb86fee78, commit: 094e5bad50
23 changed files with 1509 additions and 6 deletions

@@ -23,6 +23,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
* Added Ping-pong scheduler support for GEMM operation along the K dimension.
* Added rotating buffer feature for CK_Tile GEMM.
* Added int8 support for CK_TILE GEMM.
* Added support for elementwise kernel.
### Optimized

@@ -0,0 +1,15 @@
# Elementwise example targets 2D inputs
set(TARGET_NAME_2D_INPUT tile_example_elementwise)
add_executable(${TARGET_NAME_2D_INPUT} elementwise_example.cpp)
# Elementwise unary example targets 2D inputs
set(TARGET_NAME_2D_INPUT_UNARY tile_example_elementwise_unary)
add_executable(${TARGET_NAME_2D_INPUT_UNARY} elementwise_example_unary.cpp)
# Elementwise transpose example targets 2D inputs
set(TARGET_NAME_2D_INPUT_TRANSPOSE tile_example_elementwise_transpose)
add_executable(${TARGET_NAME_2D_INPUT_TRANSPOSE} elementwise_example_transpose.cpp)
# Elementwise example targets 4D inputs
set(TARGET_NAME_4D_INPUT tile_example_elementwise_add_4d)
add_executable(${TARGET_NAME_4D_INPUT} elementwise_example_add_4d.cpp)

@@ -0,0 +1,214 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/reference/reference_elementwise.hpp"
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "1024", "m dimension")
.insert("n", "1024", "n dimension")
.insert("stride", "-1", "stride per row, if -1 then equal to n")
.insert("v", "1", "cpu validation or not")
.insert("prec", "fp16", "precision")
.insert("warmup", "10", "cold iter")
.insert("repeat", "50", "hot iter");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t stride = arg_parser.get_int("stride");
// If stride is negative (default -1), set it to N, assuming a dense row-major layout.
if(stride < 0)
stride = N;
std::string data_type = arg_parser.get_str("prec");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
if(stride < N)
{
throw std::runtime_error("stride must be >= N");
}
// Define type aliases for clarity.
// XDataType: Data type of the input tensors.
// ComputeDataType: Data type used for intermediate computations (often float for precision).
// YDataType: Data type of the output tensor.
// XElementwiseOperation: The specific elementwise operation to perform (e.g., Add, Mul).
using XDataType = DataType;
using ComputeDataType =
float; // Using float for intermediate calculations can improve numerical stability.
using YDataType = DataType;
using XElementwiseOperation = ck_tile::element_wise::Add;
// 1. Initialize the input data on the host (CPU).
// HostTensor is a utility to manage tensor data on the CPU.
// The first argument is the shape (dimensions) of the tensor {M, N}.
// The second argument is the strides {stride, 1} for row-major layout.
// 'x_host_a' and 'x_host_b' are the two input tensors for the elementwise operation.
ck_tile::HostTensor<XDataType> x_host_a({M, N}, {stride, 1});
ck_tile::HostTensor<XDataType> x_host_b({M, N}, {stride, 1});
ck_tile::HostTensor<YDataType> y_host({M, N}, {stride, 1});
ck_tile::HostTensor<YDataType> y_validation({M, N}, {stride, 1});
std::vector<ck_tile::index_t> shape = {M, N};
// Fill the host tensors with random data.
// FillUniformDistribution populates the tensor with values from a uniform distribution,
// within an interval.
ck_tile::FillUniformDistribution<XDataType>{0.f, 5.f}(x_host_a);
ck_tile::FillUniformDistribution<XDataType>{0.f, 5.f}(x_host_b);
// 2. Create device memory buffers
// DeviceMem allocates memory on the GPU.
// The size is determined by the total number of elements and the size of DataType.
ck_tile::DeviceMem x_buf_a(x_host_a.get_element_space_size_in_bytes());
ck_tile::DeviceMem x_buf_b(x_host_b.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y_host.get_element_space_size_in_bytes());
// Copy data from host input tensors to device buffers.
x_buf_a.ToDevice(x_host_a.data());
x_buf_b.ToDevice(x_host_b.data());
// 3. Configure the kernel execution parameters.
// Dividing the problem into block tiles, warp tiles, and vectors:
// The block tile is the tile processed by a single workgroup (also called thread block).
// The warp tile is the tile processed by a single wavefront (also called warp).
// The vector is the tile processed by a single work item (also called thread).
// The problem is divided into blocks of size BlockTile, and each block is further divided
// into warp tiles of size WarpTile. Each wavefront is composed of 64 work items on CDNA
// (32 threads per warp on NVIDIA), and each work item processes one vector's worth of
// elements. Note that WarpTile/Vector should match the number of work items per wavefront
// (64 on CDNA). Vector size is set to 16 / sizeof(ComputeDataType) to maximize vectorization.
using BlockTile = ck_tile::sequence<2048>; // How many elements are handled by a block tile (the
// tensor is divided into blocks of this size)
using BlockWarps = ck_tile::sequence<8>; // How many concurrent wavefronts are in a block (each
// wavefront will cover some part of the block tile)
// WarpTile: Defines the size of the data sub-tile processed by a single wavefront per pass.
// This must be consistent with BlockTile and BlockWarps. With the current configuration
// (BlockTile=2048, BlockWarps=8, WarpTile=64), each wavefront processes 64 elements per
// pass, so the 8 wavefronts cover 8*64 = 512 elements concurrently. Since 512 < 2048, each
// wavefront iterates 2048/512 = 4 times over different sets of elements to cover the
// entire BlockTile (this repeat count is derived as kRepeatM in ElementWiseShape).
using WarpTile = ck_tile::sequence<64>;
// 4. Create the kernel
// ElementWiseShape bundles these tiling parameters.
// It calculates derived properties like threads per wavefront, repeats, vectorization and total
// block size.
using Shape = ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, ComputeDataType>;
// ElementWisePipelineProblem encapsulates all necessary information for the elementwise kernel:
// - Data types (input, compute, output).
// - Shape traits (tiling configuration).
// - The specific elementwise operation (e.g., Add).
using Problem = ck_tile::ElementWisePipelineProblem<XDataType,
ComputeDataType,
YDataType,
Shape,
XElementwiseOperation>;
// ElementWiseKernel refers to the GPU kernel class
using Kernel = ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;
// Compute flattened size
ck_tile::index_t total_elements = 1;
for(auto d : shape)
total_elements *= d;
// kBlockSize: The number of work items in a GPU workgroup (thread block).
// This is a multiple of the wavefront size (64 on CDNA). Here it is computed as
// get_warp_size() * BlockWarps = 64 * 8 = 512, which matches Shape::kBlockSize.
constexpr ck_tile::index_t kBlockSize =
ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{});
// kBlockPerCu: Hint for how many workgroups can be scheduled per Compute Unit (CU).
// This can influence occupancy and performance.
constexpr ck_tile::index_t kBlockPerCu = 1;
// kGridSize: The total number of workgroups required to process all elements.
// Each workgroup is responsible for 'elements_per_block' elements. Ceiling division
// ensures all elements are covered even when 'total_elements' is not evenly
// divisible by 'elements_per_block'.
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block;
std::cout << "grid size = " << kGridSize << std::endl;
std::cout << "Total elements = " << total_elements << std::endl;
auto input_tensors = ck_tile::make_tuple(static_cast<XDataType*>(x_buf_a.GetDeviceBuffer()),
static_cast<XDataType*>(x_buf_b.GetDeviceBuffer()));
auto input_size = ck_tile::make_tuple(M, N);
// Check if the kernel configuration is supported
if(!Kernel::IsSupportedArgument(input_size))
{
throw std::runtime_error(
"The kernel configuration is not supported for the given input size.");
}
// 5. Run the kernel
float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
Kernel{},
kGridSize,
kBlockSize,
0,
input_size,
ck_tile::make_tuple(stride, 1), // Input strides (stride per row)
ck_tile::make_tuple(stride, 1), // Output strides
input_tensors,
static_cast<YDataType*>(y_buf.GetDeviceBuffer())));
std::cout << "Average time: " << ave_time << " ms" << std::endl;
// 6. Verify the output
bool pass = true;
if(do_validation)
{
y_buf.FromDevice(y_validation.data());
auto op = [](const auto& v0, const auto& v1) { return v0 + v1; };
ck_tile::reference_binary_elementwise<XDataType, XDataType, YDataType, ComputeDataType>(
x_host_a, x_host_b, y_host, op);
pass = ck_tile::check_err(
y_validation, y_host, "Elementwise Add Error: Incorrect results!", 0.01, 0.01);
}
return pass;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
{
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
}
std::cerr << "Unsupported data type: " << data_type << std::endl;
return -3;
}

@@ -0,0 +1,159 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/reference/reference_elementwise.hpp"
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("dim0", "4", "dimension 0")
.insert("dim1", "16", "dimension 1")
.insert("dim2", "32", "dimension 2")
.insert("dim3", "32", "dimension 3")
.insert("v", "1", "cpu validation or not")
.insert("prec", "fp16", "precision")
.insert("warmup", "10", "cold iter")
.insert("repeat", "50", "hot iter");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t D0 = arg_parser.get_int("dim0");
ck_tile::index_t D1 = arg_parser.get_int("dim1");
ck_tile::index_t D2 = arg_parser.get_int("dim2");
ck_tile::index_t D3 = arg_parser.get_int("dim3");
std::string data_type = arg_parser.get_str("prec");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
using XDataType = DataType;
using ComputeDataType =
float; // Using float for intermediate calculations can improve numerical stability.
using YDataType = DataType;
using XElementwiseOperation = ck_tile::element_wise::Add;
// Initialize the input data on the host (CPU).
std::vector<ck_tile::index_t> problem_shape = {D0, D1, D2, D3};
std::vector<ck_tile::index_t> host_strides(4);
host_strides[3] = 1;
host_strides[2] = problem_shape[3];
host_strides[1] = problem_shape[2] * problem_shape[3];
host_strides[0] = problem_shape[1] * problem_shape[2] * problem_shape[3];
ck_tile::HostTensor<XDataType> x_host_a(problem_shape, host_strides);
ck_tile::HostTensor<XDataType> x_host_b(problem_shape, host_strides);
ck_tile::HostTensor<YDataType> y_host(problem_shape, host_strides);
ck_tile::HostTensor<YDataType> y_validation(problem_shape, host_strides);
ck_tile::FillUniformDistribution<XDataType>{0.f, 5.f}(x_host_a);
ck_tile::FillUniformDistribution<XDataType>{2.f, 10.f}(x_host_b);
ck_tile::DeviceMem x_buf_a(x_host_a.get_element_space_size_in_bytes());
ck_tile::DeviceMem x_buf_b(x_host_b.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y_host.get_element_space_size_in_bytes());
x_buf_a.ToDevice(x_host_a.data());
x_buf_b.ToDevice(x_host_b.data());
using BlockTile = ck_tile::sequence<256>;
using BlockWarps = ck_tile::sequence<1>;
using WarpTile = ck_tile::sequence<256>;
using Shape = ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, ComputeDataType>;
using Problem = ck_tile::ElementWisePipelineProblem<XDataType,
ComputeDataType,
YDataType,
Shape,
XElementwiseOperation>;
using Kernel = ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;
ck_tile::index_t total_elements = 1;
for(auto d : problem_shape)
total_elements *= d;
constexpr ck_tile::index_t kBlockSize =
ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{});
constexpr ck_tile::index_t kBlockPerCu = 2;
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block;
std::cout << "grid size = " << kGridSize << std::endl;
std::cout << "Total elements = " << total_elements << std::endl;
auto input_tensors = ck_tile::make_tuple(static_cast<XDataType*>(x_buf_a.GetDeviceBuffer()),
static_cast<XDataType*>(x_buf_b.GetDeviceBuffer()));
auto problem_shape_tuple =
ck_tile::make_tuple(problem_shape[0], problem_shape[1], problem_shape[2], problem_shape[3]);
auto strides_tuple =
ck_tile::make_tuple(host_strides[0], host_strides[1], host_strides[2], host_strides[3]);
// Check if the kernel configuration is supported
if(!Kernel::IsSupportedArgument(problem_shape_tuple))
{
throw std::runtime_error(
"The kernel configuration is not supported for the given input size.");
}
// Run the kernel
float ave_time = launch_kernel(
ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
Kernel{},
kGridSize,
kBlockSize,
0,
problem_shape_tuple, // ck_tile::tuple<index_t, index_t, index_t, index_t>
strides_tuple, // ck_tile::tuple<index_t, index_t, index_t, index_t> for input strides
strides_tuple, // ck_tile::tuple<index_t, index_t, index_t, index_t> for output strides
input_tensors,
static_cast<YDataType*>(y_buf.GetDeviceBuffer())));
std::cout << "Average time: " << ave_time << " ms" << std::endl;
// Verify the output
bool pass = true;
if(do_validation)
{
y_buf.FromDevice(y_validation.data());
auto op = [](const auto& v0, const auto& v1) { return v0 + v1; };
ck_tile::reference_binary_elementwise<XDataType, XDataType, YDataType, ComputeDataType>(
x_host_a, x_host_b, y_host, op);
pass = ck_tile::check_err(
y_validation, y_host, "Elementwise Add Error: Incorrect results!", 0.01, 0.01);
}
return pass;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
{
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
}
std::cerr << "Unsupported data type: " << data_type << std::endl;
return -3;
}

@@ -0,0 +1,156 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/host.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/reference/reference_transpose.hpp"
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "1024", "m dimension of input")
.insert("n", "1024", "n dimension of input")
.insert("stride_in", "-1", "stride for input M dim, if -1 then equal to n")
.insert("v", "1", "cpu validation or not")
.insert("prec", "fp16", "precision")
.insert("warmup", "10", "cold iter")
.insert("repeat", "50", "hot iter");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t stride_in = arg_parser.get_int("stride_in");
if(stride_in < 0)
stride_in = N; // Dense input: stride for M dim is N
std::string data_type = arg_parser.get_str("prec");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
if(stride_in < N)
{
throw std::runtime_error("stride_in must be >= N");
}
using XDataType = DataType;
using ComputeDataType = float;
using YDataType = DataType;
// Use PassThrough operation for transposition (data is moved, not changed)
using XElementwiseOperation = ck_tile::element_wise::PassThrough;
// 1. Initialize the input data on the host (CPU).
// Input x_host_a: M x N
// Output y_host: N x M (transposed)
ck_tile::HostTensor<XDataType> x_host_a({M, N}, {stride_in, 1});
// Output tensor y_host will have dimensions N x M.
// Assuming dense output, its stride for the N dimension will be M.
ck_tile::index_t stride_out_dim0 = M;
ck_tile::HostTensor<YDataType> y_host({N, M}, {stride_out_dim0, 1});
ck_tile::HostTensor<YDataType> y_validation({N, M}, {stride_out_dim0, 1});
// The logical shape for the element-wise operation kernel is based on the input tensor's
// elements.
std::vector<ck_tile::index_t> op_shape_vec = {M, N};
auto op_lengths = ck_tile::make_tuple(M, N); // Lens for the kernel
ck_tile::FillUniformDistribution<XDataType>{0.f, 5.f}(x_host_a);
// 2. Create device memory buffers
ck_tile::DeviceMem x_buf_a(x_host_a.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y_host.get_element_space_size_in_bytes()); // y_host is N x M
x_buf_a.ToDevice(x_host_a.data());
// 3. Configure the kernel execution parameters.
using BlockTile = ck_tile::sequence<1024>;
using BlockWarps = ck_tile::sequence<8>;
using WarpTile = ck_tile::sequence<64>;
using Shape = ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, ComputeDataType>;
// Problem definition for a single input tensor
using Problem = ck_tile::ElementWisePipelineProblem<XDataType,
ComputeDataType,
YDataType,
Shape,
XElementwiseOperation>;
using Kernel = ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;
ck_tile::index_t total_elements = M * N;
constexpr ck_tile::index_t kBlockSize = 64 * BlockWarps::at(ck_tile::number<0>{});
constexpr ck_tile::index_t kBlockPerCu = 1;
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block;
std::cout << "Input M=" << M << ", N=" << N << ", StrideIn=" << stride_in << std::endl;
std::cout << "Output N=" << N << ", M=" << M << ", StrideOut=" << stride_out_dim0 << std::endl;
std::cout << "Grid size = " << kGridSize << ", BlockSize = " << kBlockSize << std::endl;
std::cout << "Total elements = " << total_elements << std::endl;
// Input tensors tuple (single input)
auto input_tensors = ck_tile::make_tuple(static_cast<XDataType*>(x_buf_a.GetDeviceBuffer()));
// Input strides (a single strides tuple; the kernel applies it to all input tensors)
auto input_strides = ck_tile::make_tuple(stride_in, 1);
// Output strides (for N x M tensor, dense)
auto output_strides = ck_tile::make_tuple(1, stride_out_dim0);
// Check if the kernel configuration is supported
if(!Kernel::IsSupportedArgument(op_lengths))
{
throw std::runtime_error(
"The kernel configuration is not supported for the given input size.");
}
// 4. Run the kernel
float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
Kernel{},
kGridSize,
kBlockSize,
0, // Shared memory
op_lengths, // Logical dimensions for the operation (M, N)
input_strides, // Strides for input tensor(s)
output_strides, // Strides for output tensor (N, M)
input_tensors,
static_cast<YDataType*>(y_buf.GetDeviceBuffer())));
std::cout << "Average time: " << ave_time << " ms" << std::endl;
// 5. Verify the output
bool pass = true;
if(do_validation)
{
y_buf.FromDevice(y_validation.data()); // Copy result from device to y_validation
ck_tile::reference_transpose_elementwise<XDataType, YDataType>(
x_host_a, y_host); // Compute reference on host
pass = ck_tile::check_err(
y_validation, y_host, "Transpose Error: Incorrect results!", 0.01, 0.01);
}
return pass;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
{
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
}
std::cerr << "Unsupported data type: " << data_type << std::endl;
return -3;
}

@@ -0,0 +1,147 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/reference/reference_elementwise.hpp"
auto create_args(int argc, char* argv[])
{
ck_tile::ArgParser arg_parser;
arg_parser.insert("m", "1024", "m dimension")
.insert("n", "1024", "n dimension")
.insert("stride", "-1", "stride per row, if -1 then equal to n")
.insert("v", "1", "cpu validation or not")
.insert("prec", "fp16", "precision")
.insert("warmup", "10", "cold iter")
.insert("repeat", "50", "hot iter");
bool result = arg_parser.parse(argc, argv);
return std::make_tuple(result, arg_parser);
}
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
ck_tile::index_t M = arg_parser.get_int("m");
ck_tile::index_t N = arg_parser.get_int("n");
ck_tile::index_t stride = arg_parser.get_int("stride");
if(stride < 0)
stride = N;
std::string data_type = arg_parser.get_str("prec");
int do_validation = arg_parser.get_int("v");
int warmup = arg_parser.get_int("warmup");
int repeat = arg_parser.get_int("repeat");
if(stride < N)
{
throw std::runtime_error("stride must be >= N");
}
using XDataType = DataType;
using YDataType = DataType;
using ComputeDataType = float;
using XElementwiseOperation = ck_tile::element_wise::UnarySquare;
// 1. Initialize the input data on the host
ck_tile::HostTensor<XDataType> x_host_a({M, N}, {stride, 1});
ck_tile::HostTensor<YDataType> y_host({M, N}, {stride, 1});
ck_tile::HostTensor<YDataType> y_validation({M, N}, {stride, 1});
std::vector<ck_tile::index_t> shape = {M, N};
ck_tile::FillUniformDistribution<XDataType>{0.f, 5.f}(x_host_a);
// 2. Create device memory buffers and copy input data from host to device
ck_tile::DeviceMem x_buf_a(x_host_a.get_element_space_size_in_bytes());
ck_tile::DeviceMem y_buf(y_host.get_element_space_size_in_bytes());
x_buf_a.ToDevice(x_host_a.data());
// 3. Create the kernel
// Dividing the problem into blocktile, warptile, and vector
using BlockTile = ck_tile::sequence<2048>; // Size of the block tile (Entire problem is divided
// into blocks of this size)
using BlockWarps = ck_tile::sequence<8>; // How many concurrent warps are in a block (Each warp
// will cover some part of blockTile)
using WarpTile = ck_tile::sequence<64>; // How many elements are covered by a warp
using Shape = ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, ComputeDataType>;
using Problem = ck_tile::ElementWisePipelineProblem<XDataType,
XDataType, // ComputeDataType is same as
// XDataType in the unary case
YDataType,
Shape,
XElementwiseOperation>;
using Kernel = ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;
// Compute flattened size
ck_tile::index_t total_elements = 1;
for(auto d : shape)
total_elements *= d;
constexpr ck_tile::index_t kBlockSize =
ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{});
constexpr ck_tile::index_t kBlockPerCu = 1;
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block;
std::cout << "grid size = " << kGridSize << std::endl;
std::cout << "Total elements = " << total_elements << std::endl;
auto input_tensors = ck_tile::make_tuple(static_cast<XDataType*>(x_buf_a.GetDeviceBuffer()));
auto input_size = ck_tile::make_tuple(M, N);
// Check if the kernel configuration is supported
if(!Kernel::IsSupportedArgument(input_size))
{
throw std::runtime_error(
"The kernel configuration is not supported for the given input size.");
}
// 4. Run the kernel
float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
Kernel{},
kGridSize,
kBlockSize,
0,
input_size,
ck_tile::make_tuple(stride, 1), // Input strides (stride per row)
ck_tile::make_tuple(stride, 1), // Output strides
input_tensors,
static_cast<YDataType*>(y_buf.GetDeviceBuffer())));
std::cout << "Average time: " << ave_time << " ms" << std::endl;
// 5. Verify the output
bool pass = true;
if(do_validation)
{
y_buf.FromDevice(y_validation.data());
auto op = [](const auto& v0) { return v0 * v0; };
ck_tile::reference_unary_elementwise<XDataType, YDataType, YDataType>(x_host_a, y_host, op);
pass = ck_tile::check_err(
y_validation, y_host, "Elementwise Add Error: Incorrect results!", 0.01, 0.01);
}
return pass;
}
int main(int argc, char* argv[])
{
auto [result, arg_parser] = create_args(argc, argv);
if(!result)
return -1;
const std::string data_type = arg_parser.get_str("prec");
if(data_type == "fp16")
{
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
}
std::cerr << "Unsupported data type: " << data_type << std::endl;
return -3;
}

@@ -20,6 +20,7 @@ add_subdirectory(17_grouped_gemm)
add_subdirectory(18_flatmm)
add_subdirectory(19_gemm_multi_d)
add_subdirectory(20_grouped_convolution)
add_subdirectory(21_elementwise)
add_subdirectory(35_batched_transpose)
add_subdirectory(37_transpose)
add_subdirectory(38_block_scale_gemm)

@@ -264,10 +264,14 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
#define TP_COM_() static_assert(I < size(), "wrong! out of range")
// clang-format off
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) const { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const & { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) const & { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() & { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) & { TP_COM_(); return get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() && { TP_COM_(); return impl::getv<I>(std::move(*this)); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) && { TP_COM_(); return std::move(*this).template get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const && { TP_COM_(); return impl::getv<I>(std::move(*this)); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) const &&{ TP_COM_(); return std::move(*this).template get<I>(); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at() const { TP_COM_(); return impl::getv<I>(*this); }
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at(number<I>) const { TP_COM_(); return get<I>(); }
@@ -470,6 +474,12 @@ transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, sequence<Is...>)
return make_tuple(f(x.at(number<Is>{}), y.at(number<Is>{}), z.at(number<Is>{}))...);
}
template <typename F, typename Tuple, index_t... Is>
constexpr decltype(auto) apply_impl(F&& f, Tuple&& t, sequence<Is...>)
{
return std::forward<F>(f)(std::forward<Tuple>(t).get(number<Is>{})...);
}
} // namespace detail
template <typename F, typename X>
@@ -493,6 +503,13 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y,
f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
}
template <typename F, typename Tuple>
constexpr decltype(auto) apply(F&& f, Tuple&& t)
{
constexpr index_t N = std::decay_t<Tuple>::size();
return detail::apply_impl(std::forward<F>(f), std::forward<Tuple>(t), make_index_sequence<N>{});
}
namespace detail {
template <typename F, typename X, index_t... Is>
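
The new apply mirrors std::apply for ck_tile::tuple: it unpacks the tuple elements and passes them to the callable as individual arguments, so the callable receives exactly as many arguments as the tuple has elements. A minimal usage sketch, assuming only the apply and make_tuple shown in this diff:

#include "ck_tile/core.hpp"
// apply(f, t) expands to f(t.get<0>(), ..., t.get<N-1>()).
int sum_tuple()
{
    auto t = ck_tile::make_tuple(1, 2, 3);
    return ck_tile::apply([](int a, int b, int c) { return a + b + c; }, t); // returns 6
}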

@@ -38,6 +38,7 @@
#include "ck_tile/host/reference/reference_rowwise_quantization2d.hpp"
#include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/reference/reference_topk.hpp"
#include "ck_tile/host/reference/reference_transpose.hpp"
#include "ck_tile/host/rotating_buffers.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/stream_utils.hpp"

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

@@ -0,0 +1,33 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace ck_tile {
template <typename ADataType, typename BDataType>
void reference_transpose_elementwise(const HostTensor<ADataType>& a, HostTensor<BDataType>& b)
{
ck_tile::index_t M = static_cast<ck_tile::index_t>(a.mDesc.get_lengths()[0]);
ck_tile::index_t N = static_cast<ck_tile::index_t>(a.mDesc.get_lengths()[1]);
// Ensure the b tensor is sized correctly for N x M
if(static_cast<ck_tile::index_t>(b.mDesc.get_lengths()[0]) != N ||
static_cast<ck_tile::index_t>(b.mDesc.get_lengths()[1]) != M)
{
throw std::runtime_error("Output tensor b has incorrect dimensions for transpose.");
}
auto f = [&](auto i, auto j) {
auto v_a = a(i, j);
b(j, i) = ck_tile::type_convert<BDataType>(v_a);
};
make_ParallelTensorFunctor(f, M, N)(std::thread::hardware_concurrency());
}
} // namespace ck_tile
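
For reference, a minimal host-side usage sketch of this transpose helper (shapes chosen purely for illustration; HostTensor and FillUniformDistribution as used in the examples above):

ck_tile::HostTensor<float> a({4, 8}, {8, 1}); // M x N input, row-major
ck_tile::HostTensor<float> b({8, 4}, {4, 1}); // N x M output, row-major
ck_tile::FillUniformDistribution<float>{0.f, 1.f}(a);
ck_tile::reference_transpose_elementwise<float, float>(a, b); // b(j, i) = a(i, j)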

@@ -3,6 +3,11 @@
#pragma once
#include "ck_tile/ops/elementwise/binary_elementwise_operation.hpp"
#include "ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp"
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp"
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp"
#include "ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"

@@ -0,0 +1,94 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
namespace element_wise {
struct Add
{
template <typename Y, typename X0, typename X1>
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
template <>
__host__ __device__ constexpr void
operator()<float>(float& y, const float& x0, const float& x1) const
{
y = x0 + x1;
};
template <>
__host__ __device__ constexpr void
operator()<double>(double& y, const double& x0, const double& x1) const
{
y = x0 + x1;
};
template <>
__host__ __device__ constexpr void
operator()<float>(float& y, const float& x0, const half_t& x1) const
{
y = x0 + type_convert<float>(x1);
};
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const float& x0, const float& x1) const
{
y = type_convert<half_t>(x0 + x1);
};
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
{
y = type_convert<half_t>(x0) + x1;
};
template <>
__host__ __device__ constexpr void
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
{
y = x0 + x1;
};
template <>
__host__ __device__ constexpr void
operator()<float>(float& y, const float& x0, const bf16_t& x1) const
{
const float x1_tmp = type_convert<float>(x1);
y = x0 + x1_tmp;
}
template <>
__host__ __device__ constexpr void
operator()<bf16_t>(bf16_t& y, const bf16_t& x0, const bf16_t& x1) const
{
const float x1_tmp = type_convert<float>(x0);
const float x2_tmp = type_convert<float>(x1);
const float y_tmp = x1_tmp + x2_tmp;
y = type_convert<bf16_t>(y_tmp);
}
template <>
__host__ __device__ constexpr void
operator()<bf16_t>(bf16_t& y, const float& x0, const bf16_t& x1) const
{
const float x2_tmp = type_convert<float>(x1);
const float y_tmp = x0 + x2_tmp;
y = type_convert<bf16_t>(y_tmp);
}
template <>
__host__ __device__ constexpr void
operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
{
y = x0 + x1;
};
};
} // namespace element_wise
} // namespace ck_tile
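
The functor writes its result through the first (output) reference and dispatches on the output/input type combination, converting through float where a narrow type is involved. A small host-side sketch (types from ck_tile/core.hpp):

ck_tile::element_wise::Add add{};
float y_f = 0.f;
add(y_f, 1.5f, 2.5f); // float = float + float -> 4.0f
ck_tile::half_t y_h;
add(y_h, 1.5f, 2.5f); // half_t = float + float, converted on store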

@@ -0,0 +1,123 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp"
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp"
namespace ck_tile {
template <typename Problem_, typename Policy_>
struct ElementWiseKernel
{
using Problem = ck_tile::remove_cvref_t<Problem_>;
using Policy = ck_tile::remove_cvref_t<Policy_>;
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
using ElementWiseOperation = ck_tile::remove_cvref_t<typename Problem::ElementWiseOperation>;
template <typename... XDataType, typename Dims>
CK_TILE_DEVICE void operator()(Dims lens,
Dims input_strides,
Dims output_strides,
const tuple<XDataType...>& input_tensors,
YDataType* p_y) const
{
using S = typename Problem::BlockShape;
// Setup block-level coordinates and transforms
const index_t iM = get_block_id() * S::kBlockM;
const auto merge_transform = make_merge_transform(lens);
// Load all input tiles into registers.
// The lambda structure here is intended to minimize the lifetime
// of intermediate objects (views, windows) used for loading.
const auto x_tiles = ck_tile::generate_tuple(
[&](auto i) {
const auto tensor_view = make_naive_tensor_view<address_space_enum::global>(
input_tensors.get(i), lens, input_strides, number<S::kVectorM>{}, number<1>{});
const auto transformed_tensor = pad_tensor_view(
transform_tensor_view(tensor_view,
ck_tile::make_tuple(merge_transform),
ck_tile::make_tuple(make_index_sequence<Dims::size()>{}),
ck_tile::make_tuple(sequence<0>{})),
ck_tile::make_tuple(number<S::kBlockM>{}),
sequence<Problem::kPad>{});
const auto x_window =
make_tile_window(transformed_tensor,
ck_tile::make_tuple(number<S::kBlockM>{}),
{iM},
Policy::template MakeXBlockTileDistribution<Problem>());
return load_tile(x_window);
},
number<sizeof...(XDataType)>{});
// Setup output tile in registers.
const auto& x_tile0 = x_tiles.get(number<0>{});
auto y_tile = make_static_distributed_tensor<YDataType>(x_tile0.get_tile_distribution());
// Perform element-wise computation.
const auto spans = x_tile0.get_distributed_spans();
sweep_tile_span(spans[number<0>{}], [&](auto idx) {
const auto tile_idx = make_tuple(idx);
apply(
[&](auto&&... tiles) {
ElementWiseOperation{}(y_tile(tile_idx),
type_convert<ComputeDataType>(tiles[tile_idx])...);
},
x_tiles);
});
// Setup output window and store the result tile.
const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
p_y, lens, output_strides, number<S::kVectorM>{});
const auto transformed_y_m_n = pad_tensor_view(
transform_tensor_view(y_m_n,
ck_tile::make_tuple(merge_transform),
ck_tile::make_tuple(make_index_sequence<Dims::size()>{}),
ck_tile::make_tuple(sequence<0>{})),
ck_tile::make_tuple(number<S::kBlockM>{}),
sequence<Problem::kPad>{});
auto y_window = make_tile_window(transformed_y_m_n,
make_tuple(number<S::kBlockM>{}),
{iM},
y_tile.get_tile_distribution());
store_tile(y_window, cast_tile<YDataType>(y_tile));
}
template <typename... Ints>
CK_TILE_HOST static bool IsSupportedArgument(const ck_tile::tuple<Ints...>& input_sizes)
{
int total_elements = 1;
const auto kVectorM = Problem_::BlockShape::kVectorM;
apply([&](auto&&... args) { ((total_elements *= args), ...); }, input_sizes);
if((total_elements % kVectorM) != 0)
{
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
{
CK_TILE_ERROR("Conditions not met: total number of input elements (",
total_elements,
") should be multiple of the vectorization size (",
kVectorM,
")");
}
return false;
}
return true;
}
};
} // namespace ck_tile
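
IsSupportedArgument only requires the flattened element count to be a multiple of the vectorization width kVectorM = 16 / sizeof(ComputeDataType). A hedged sketch of the check, assuming a Kernel alias instantiated with a float ComputeDataType as in the examples above (kVectorM = 4):

bool ok  = Kernel::IsSupportedArgument(ck_tile::make_tuple(1024)); // true:  1024 % 4 == 0
bool bad = Kernel::IsSupportedArgument(ck_tile::make_tuple(513));  // false: 513 % 4 == 1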

@@ -0,0 +1,29 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
struct ElementWiseDefaultPolicy
{
template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
{
using S = typename Problem::BlockShape;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>, // Replicate
tuple<sequence<S::kRepeatM,
S::kWarpPerBlockM,
S::kThreadPerWarpM,
S::kVectorM>>, // Hierarchical
tuple<sequence<1>, sequence<1>>, // Parallel
tuple<sequence<1>, sequence<2>>, // Parallel
sequence<1, 1>, // Yield
sequence<0, 3>>{} // Yield
);
}
};
} // namespace ck_tile

@@ -0,0 +1,26 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <typename XDataType_,
typename ComputeDataType_,
typename YDataType_,
typename BlockShape_,
typename ElementWiseOperation_,
bool kPad_ = true>
struct ElementWisePipelineProblem
{
using XDataType = remove_cvref_t<XDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using YDataType = remove_cvref_t<YDataType_>;
using BlockShape = remove_cvref_t<BlockShape_>;
using ElementWiseOperation = remove_cvref_t<ElementWiseOperation_>;
static constexpr bool kPad = kPad_;
};
} // namespace ck_tile

@@ -0,0 +1,29 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace ck_tile {
template <typename BlockWarps, typename BlockTile, typename WarpTile, typename ComputeDataType>
struct ElementWiseShape
{
static constexpr index_t kBlockM = BlockTile::at(number<0>{});
static constexpr index_t kWarpM = WarpTile::at(number<0>{});
static constexpr index_t kVectorM = 16 / sizeof(ComputeDataType);
static constexpr index_t kWarpPerBlockM = BlockWarps::at(number<0>{});
static constexpr index_t kThreadPerWarpM = kWarpM / kVectorM;
static constexpr index_t kRepeatM = kBlockM / (kWarpPerBlockM * kWarpM);
static constexpr index_t kBlockSize =
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
};
} // namespace ck_tile
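
As a worked example of these formulas, using the configuration from the binary Add example (BlockWarps = 8, BlockTile = 2048, WarpTile = 64, ComputeDataType = float, warp size 64), the derived constants work out as follows (a sketch, not library code):

constexpr int kVectorM        = 16 / sizeof(float); // 4 elements per work item
constexpr int kThreadPerWarpM = 64 / kVectorM;      // kWarpM / kVectorM = 16
constexpr int kRepeatM        = 2048 / (8 * 64);    // kBlockM / (kWarpPerBlockM * kWarpM) = 4
constexpr int kBlockSize      = 64 * 8;             // warp size * BlockWarps = 512 work items
static_assert(kRepeatM == 4 && kBlockSize == 512, "matches the example launch configuration");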

@@ -1,5 +1,5 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once

@@ -5,6 +5,8 @@ add_subdirectory(batched_gemm)
add_subdirectory(grouped_gemm)
add_subdirectory(gemm_multi_d)
add_subdirectory(data_type)
add_subdirectory(container)
add_subdirectory(elementwise)
# Not including these tests as there is a bug on gfx90a and gfx942
# resulting in "GPU core dump"
#add_subdirectory(moe_smoothquant)

@@ -0,0 +1,6 @@
if(GPU_TARGETS MATCHES "gfx9")
add_gtest_executable(test_ck_tile_tuple_apply test_tuple_apply.cpp)
if(result EQUAL 0)
target_link_libraries(test_ck_tile_tuple_apply PRIVATE utility)
endif()
endif()

@@ -0,0 +1,223 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>
#include "ck_tile/core.hpp"
using namespace ck_tile;
class TestCkTileTupleApply : public ::testing::Test
{
public:
// Test functors for different scenarios
struct AddFunction
{
template <typename... Args>
CK_TILE_HOST_DEVICE constexpr auto operator()(Args... args) const
{
return (args + ...);
}
};
struct MultiplyFunction
{
template <typename... Args>
CK_TILE_HOST_DEVICE constexpr auto operator()(Args... args) const
{
return (args * ...);
}
};
struct MaxFunction
{
template <typename T>
CK_TILE_HOST_DEVICE constexpr T operator()(T a) const
{
return a;
}
template <typename T, typename... Args>
CK_TILE_HOST_DEVICE constexpr T operator()(T a, Args... args) const
{
auto rest_max = operator()(args...);
return a > rest_max ? a : rest_max;
}
};
struct ReturnTupleFunction
{
template <typename... Args>
CK_TILE_HOST_DEVICE constexpr auto operator()(Args... args) const
{
return make_tuple(args..., sizeof...(args));
}
};
};
TEST_F(TestCkTileTupleApply, BasicArithmetic)
{
// Test with simple arithmetic operations
auto t1 = make_tuple(1, 2, 3);
auto result1 = apply(AddFunction{}, t1);
EXPECT_EQ(result1, 6);
auto t2 = make_tuple(2, 3, 4, 5);
auto result2 = apply(MultiplyFunction{}, t2);
EXPECT_EQ(result2, 120);
}
TEST_F(TestCkTileTupleApply, SingleElement)
{
// Test with single element tuple
auto t1 = make_tuple(42);
auto result1 = apply(AddFunction{}, t1);
EXPECT_EQ(result1, 42);
auto result2 = apply(MultiplyFunction{}, t1);
EXPECT_EQ(result2, 42);
}
TEST_F(TestCkTileTupleApply, EmptyTuple)
{
// Test with empty tuple
auto t = tuple<>{};
auto result = apply([]() { return 100; }, t);
EXPECT_EQ(result, 100);
}
TEST_F(TestCkTileTupleApply, DifferentTypes)
{
// Test with different data types
auto t1 = make_tuple(1, 2.5f, 3.0);
auto result1 = apply(AddFunction{}, t1);
EXPECT_FLOAT_EQ(result1, 6.5f);
// Test with mixed integer and floating point
auto t2 = make_tuple(10, 0.5f);
auto result2 = apply(MultiplyFunction{}, t2);
EXPECT_FLOAT_EQ(result2, 5.0f);
}
TEST_F(TestCkTileTupleApply, ReturnTuple)
{
// Test function that returns a tuple
auto t = make_tuple(1, 2, 3);
auto result = apply(ReturnTupleFunction{}, t);
EXPECT_EQ(result.get<0>(), 1);
EXPECT_EQ(result.get<1>(), 2);
EXPECT_EQ(result.get<2>(), 3);
EXPECT_EQ(result.get<3>(), 3); // size
}
TEST_F(TestCkTileTupleApply, LambdaFunction)
{
// Test with lambda functions
auto t1 = make_tuple(5, 10, 15);
auto result1 = apply([](auto a, auto b, auto c) { return a + b + c; }, t1);
EXPECT_EQ(result1, 30);
// Test lambda with capture
int multiplier = 2;
auto result2 =
apply([multiplier](auto a, auto b) { return (a + b) * multiplier; }, make_tuple(3, 7));
EXPECT_EQ(result2, 20);
}
TEST_F(TestCkTileTupleApply, ConstexprContext)
{
// Test in constexpr context
constexpr auto t = make_tuple(2, 3, 4);
constexpr auto result = apply(MultiplyFunction{}, t);
static_assert(result == 24, "Constexpr apply should work");
EXPECT_EQ(result, 24);
}
TEST_F(TestCkTileTupleApply, ReferenceTypes)
{
// Test with reference types using tie
int a = 1, b = 2, c = 3;
auto ref_tuple = tie(a, b, c);
// Function that modifies references
apply(
[](auto& x, auto& y, auto& z) {
x += 10;
y += 20;
z += 30;
},
ref_tuple);
EXPECT_EQ(a, 11);
EXPECT_EQ(b, 22);
EXPECT_EQ(c, 33);
}
TEST_F(TestCkTileTupleApply, MoveSemantics)
{
// Test with move semantics
auto t = make_tuple(1, 2, 3);
auto result = apply(AddFunction{}, std::move(t));
EXPECT_EQ(result, 6);
}
TEST_F(TestCkTileTupleApply, NumberTypes)
{
// Test with ck_tile::number types
auto t = make_tuple(number<1>{}, number<2>{}, number<3>{});
auto result = apply([](auto a, auto b, auto c) { return a + b + c; }, t);
EXPECT_EQ(result, 6);
}
TEST_F(TestCkTileTupleApply, ElementwiseOperation)
{
// Test simulating elementwise operations
auto input1 = make_tuple(1.0f, 2.0f, 3.0f);
auto input2 = make_tuple(4.0f, 5.0f, 6.0f);
auto add_elementwise = [](const auto& a, const auto& b) {
return apply(
[&b](auto... args_a) {
return apply(
[args_a...](auto... args_b) { return make_tuple((args_a + args_b)...); }, b);
},
a);
};
auto result = add_elementwise(input1, input2);
EXPECT_FLOAT_EQ(result.get<0>(), 5.0f);
EXPECT_FLOAT_EQ(result.get<1>(), 7.0f);
EXPECT_FLOAT_EQ(result.get<2>(), 9.0f);
}
template <typename T>
class TestCkTileTupleApplySize : public TestCkTileTupleApply
{
protected:
static constexpr int Size = T::value;
};
using TupleSizes = ::testing::Types<std::integral_constant<int, 1>,
std::integral_constant<int, 2>,
std::integral_constant<int, 3>,
std::integral_constant<int, 4>,
std::integral_constant<int, 8>,
std::integral_constant<int, 16>>;
TYPED_TEST_SUITE(TestCkTileTupleApplySize, TupleSizes);
TYPED_TEST(TestCkTileTupleApplySize, GeneratedTupleSum)
{
constexpr int N = TypeParam::value;
// Generate tuple with values 1, 2, 3, ..., N
constexpr auto t = generate_tuple([](auto i) { return i.value + 1; }, number<N>{});
// Sum all elements
constexpr auto result = apply(TestCkTileTupleApply::AddFunction{}, t);
// Expected sum: 1 + 2 + ... + N = N*(N+1)/2
constexpr int expected = N * (N + 1) / 2;
static_assert(result == expected);
}

@@ -0,0 +1,6 @@
if(GPU_TARGETS MATCHES "gfx9")
add_gtest_executable(test_ck_tile_elementwise_1d test_elementwise_1d.cpp)
if(result EQUAL 0)
target_link_libraries(test_ck_tile_elementwise_1d PRIVATE utility)
endif()
endif()

@@ -0,0 +1,216 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#include <gtest/gtest.h>
#include <vector>
#include <cmath> // For std::abs
#include <tuple>
#include <type_traits> // For std::is_same_v, std::is_floating_point_v
#include <utility> // For std::index_sequence, std::forward
#include "ck_tile/core.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp"
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp"
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp"
#include "ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp"
#include "ck_tile/ops/elementwise/binary_elementwise_operation.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
// Traits to get number of inputs for an elementwise operation
template <typename Op>
struct elementwise_op_traits;
template <>
struct elementwise_op_traits<ck_tile::element_wise::Add>
{
static constexpr int num_inputs = 2;
};
template <>
struct elementwise_op_traits<ck_tile::element_wise::Relu>
{
static constexpr int num_inputs = 1;
};
template <std::size_t D, typename F>
auto make_uniform_array_with_factory(F&& factory)
{
return [&]<std::size_t... Is>(std::index_sequence<Is...>)
{
return std::array<std::invoke_result_t<F, std::size_t>, D>{factory(Is)...};
}
(std::make_index_sequence<D>{});
}
template <typename Tuple>
class TestCkTileElementwise : public ::testing::Test
{
protected:
using XDataType = std::tuple_element_t<0, Tuple>;
using YDataType = std::tuple_element_t<1, Tuple>;
using ComputeDataType = std::tuple_element_t<2, Tuple>;
using ElementwiseOpType = std::tuple_element_t<3, Tuple>;
using BlockWarps_ = std::tuple_element_t<4, Tuple>;
using BlockTile_ = std::tuple_element_t<5, Tuple>;
using WarpTile_ = std::tuple_element_t<6, Tuple>;
using TestElementWiseShape =
ck_tile::ElementWiseShape<BlockWarps_, BlockTile_, WarpTile_, ComputeDataType>;
static constexpr int NumInputs = elementwise_op_traits<ElementwiseOpType>::num_inputs;
void RunTest(ck_tile::index_t total_m_elements)
{
// Dims and Strides (1D example)
auto lens = ck_tile::make_tuple(total_m_elements);
auto strides = ck_tile::make_tuple(
static_cast<ck_tile::index_t>(1)); // Strides for the single dimension
// Host Tensors
auto h_xs = make_uniform_array_with_factory<NumInputs>([&](std::size_t) {
auto ret = ck_tile::HostTensor<XDataType>({total_m_elements});
ck_tile::FillUniformDistribution<XDataType>{0.f, 5.f}(ret);
return ret;
});
ck_tile::HostTensor<YDataType> h_y({total_m_elements});
h_y.SetZero();
ck_tile::HostTensor<YDataType> h_y_ref({total_m_elements});
h_y_ref.SetZero();
// Device Buffers
auto d_xs_mems_owner = make_uniform_array_with_factory<NumInputs>(
[&](std::size_t i) { return ck_tile::DeviceMem(h_xs[i]); });
for(int i = 0; i < NumInputs; ++i)
{
d_xs_mems_owner[i].ToDevice(h_xs[i].data());
}
ck_tile::DeviceMem d_y_mem(h_y);
d_y_mem.SetZero();
auto d_x_ptrs_tuple = [&]<std::size_t... Is>(std::index_sequence<Is...>)
{
return ck_tile::make_tuple(
static_cast<const XDataType*>(d_xs_mems_owner[Is].GetDeviceBuffer())...);
}
(std::make_index_sequence<NumInputs>{});
YDataType* p_y_device = static_cast<YDataType*>(d_y_mem.GetDeviceBuffer());
// Problem and Policy
using Problem = ck_tile::ElementWisePipelineProblem<XDataType,
ComputeDataType,
YDataType,
TestElementWiseShape,
ElementwiseOpType>;
using Policy = ck_tile::ElementWiseDefaultPolicy;
ck_tile::ElementWiseKernel<Problem, Policy> ew_kernel;
// Launch configuration
ck_tile::index_t grid_size =
(total_m_elements + TestElementWiseShape::kBlockM - 1) / TestElementWiseShape::kBlockM;
dim3 grid(grid_size, 1, 1);
dim3 block(TestElementWiseShape::kBlockSize, 1, 1);
constexpr ck_tile::index_t kBlockPerCu = 1;
ck_tile::stream_config s{nullptr, false, 0}; // Default stream, no timing, no log
// Check if the kernel configuration is supported
if(!ew_kernel.IsSupportedArgument(lens))
{
throw std::runtime_error(
"The kernel configuration is not supported for the given input size.");
}
ck_tile::launch_kernel(
s,
ck_tile::make_kernel<TestElementWiseShape::kBlockSize, // MaxThreadPerBlock
kBlockPerCu> // MinBlockPerCu
(ew_kernel,
grid,
block,
0, // actual shared memory
lens,
strides, // input strides
strides, // output strides
d_x_ptrs_tuple,
p_y_device));
d_y_mem.FromDevice(h_y.data());
// Reference computation on host
ElementwiseOpType op_host;
for(ck_tile::index_t i = 0; i < total_m_elements; ++i)
{
auto get_host_op_args = [&]<std::size_t... Is>(std::index_sequence<Is...>)
{
return ck_tile::make_tuple(static_cast<ComputeDataType>(h_xs[Is](i))...);
}
(std::make_index_sequence<NumInputs>{});
YDataType temp_y_val;
ck_tile::apply(
[&](auto&&... host_input_args) {
op_host(temp_y_val,
std::forward<decltype(host_input_args)>(host_input_args)...);
},
get_host_op_args);
h_y_ref(i) = temp_y_val;
}
// Check results
EXPECT_TRUE(check_err(h_y, h_y_ref, "Error: Incorrect results!", 1e-5, 1e-5));
}
};
// Shape parameters (can be shared or varied per test type)
using Shape1_BlockWarps = ck_tile::sequence<1>; // 1D warp arrangement in M
using Shape1_BlockTile = ck_tile::sequence<256>; // M-dimension of block tile
using Shape1_WarpTile = ck_tile::sequence<64>; // M-dimension of warp tile
// Test configurations
using TestConfig_F32_Add = std::tuple<float,
float,
float,
ck_tile::element_wise::Add,
Shape1_BlockWarps,
Shape1_BlockTile,
Shape1_WarpTile>;
using TestConfig_F32_Relu = std::tuple<float,
float,
float,
ck_tile::element_wise::Relu,
Shape1_BlockWarps,
Shape1_BlockTile,
Shape1_WarpTile>;
using TestConfig_F16_Add = std::tuple<ck_tile::half_t,
ck_tile::half_t,
float, // Compute in float for half
ck_tile::element_wise::Add,
Shape1_BlockWarps,
Shape1_BlockTile,
Shape1_WarpTile>;
using TestTypes = ::testing::Types<TestConfig_F32_Add, TestConfig_F32_Relu, TestConfig_F16_Add>;
TYPED_TEST_SUITE(TestCkTileElementwise, TestTypes);
TYPED_TEST(TestCkTileElementwise, RunElementwise_1024) { this->RunTest(1024); }
TYPED_TEST(TestCkTileElementwise, RunElementwise_513)
{
EXPECT_THROW((this->RunTest(513)),
std::runtime_error); // Test with an input size that's not a multiple of kVectorM
}
TYPED_TEST(TestCkTileElementwise, RunElementwise_516)
{
this->RunTest(516); // Test with an input size that's not a multiple of blockM
}
TYPED_TEST(TestCkTileElementwise, RunElementwise_Small_32)
{
this->RunTest(32); // Test with a very small size
}