mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-14 10:09:41 +00:00
[CK_TILE] Support for elementwise kernel (#2246)
* Elementwise kernel implementation
Co-authored-by: Sami Aario <samaario@amd.com>
Co-authored-by: Mohsen Saffari <mohsen.saffari@amd.com>
Co-authored-by: yashagar <yashagar@amd.com>
* Elementwise with generalized nDims
* Adding the n-ary input tensor feature
* Generalize dimensions on top of inputs
* Add TFLOPS + remove std usage for tuples
* 1D basecase optimization
* Cleanup code + refactoring to a common interface
* Generalize to unary and add an example
* Cleanup, refactoring and commenting
* Suggestions for LWPCK-3170: elementwise kernel improvements
* Clang-format: remod.py
* Replace InputTensorType with XDataType as the type of input_tensors
* Add Tuple::apply and use it in ElementWiseKernel::operator to call operation with the exact number of arguments in xs
* Move examples to folder 19_elementwise
* Add missing copyright headers and fix some existing ones
* Replace an assert with throw std::runtime_error in elementwise example
* Avoid reading the output by using make_static_distributed_tensor for y_tile
* Removed two unused includes
* No need to move windows to the next block when each workgroup processes a single tile
* Only copy input tensors to the device
* Use get_warp_size to obtain warp size, and use ceiling division for grid size also for the unary example
* Adding output strides to the kernel, transposition example and update the other examples
* Changes made by remod.py
* Use default template parameter values for memory operation and coherence in a call to make_naive_tensor_view
* Move binary operations to include/ck_tile/ops/elementwise/binary_elementwise_operation.hpp
* Reuse generic reference binary/unary operation in examples + refactoring the transpose reference
* Fix comments in elementwise_example.cpp
- Refer to AMD terminology except when suggesting NVIDIA alternatives in parentheses
- ElementWiseTraits was renamed to ElementWiseShape
- Adopt suggestions made by Copilot when prompted to check for factual or typographical errors
* Simplify CMakeLists.txt and remove the unused variables this uncovers
* Rename a file and fix some copyright statements
* Changes made by script/clang-format-overwrite.sh
* Add basic unit test for ElementWiseKernel
* Remove left-over uninformative comment in apply unit test
* Changes made by clang-format-overwrite.sh
* fixup! Use default template parameter values for memory operation and coherence in a call to make_naive_tensor_view
* Clean up test_tuple_apply.cpp and test_elementwise_1d.cpp
* Use make_uniform_array_with_factory to define h_xs and d_xs_mems_owner as type std::array
* Use a DeviceMem constructor that calls get_element_space_size_in_bytes internally
* Move examples to folder 20_elementwise
* Reduced register pressure on the CK tile elementwise kernel + add 4d input example to be able to benchmark against old CK
* Fix Clang formatting
* Bump up the elementwise example folder number
* Elementwise: add padding + minor cleanup
* Add Vector Size inference + fix issue with wrong vectorization due to missing GuaranteedLastDimensionVectorStride setting in make_naive_tensor_view
* Add isSupportedArg to Elementwise kernel + adapt example and unit tests
* Fix clang-format on the unit test file
---------
Co-authored-by: Damien Lejeune <damien.lejeune@amd.com>
Co-authored-by: Sami Aario <samaario@amd.com>
Co-authored-by: Mohsen Saffari <mohsen.saffari@amd.com>
Co-authored-by: Aviral Goel <aviral.goel@amd.com>
[ROCm/composable_kernel commit: 606b0cc947]
This commit is contained in:
committed by
GitHub
parent
bdb86fee78
commit
094e5bad50
@@ -23,6 +23,7 @@ Documentation for Composable Kernel available at [https://rocm.docs.amd.com/proj
|
||||
* Added Ping-pong scheduler support for GEMM operation along the K dimension.
|
||||
* Added rotating buffer feature for CK_Tile GEMM.
|
||||
* Added int8 support for CK_TILE GEMM.
|
||||
* Added support for elementwise kernel.
|
||||
|
||||
### Optimized
|
||||
|
||||
|
||||
15
example/ck_tile/21_elementwise/CMakeLists.txt
Normal file
15
example/ck_tile/21_elementwise/CMakeLists.txt
Normal file
@@ -0,0 +1,15 @@
|
||||
# Elementwise example targets 2D inputs
|
||||
set(TARGET_NAME_2D_INPUT tile_example_elementwise)
|
||||
add_executable(${TARGET_NAME_2D_INPUT} elementwise_example.cpp)
|
||||
|
||||
# Elementwise unary example targets 2D inputs
|
||||
set(TARGET_NAME_2D_INPUT_UNARY tile_example_elementwise_unary)
|
||||
add_executable(${TARGET_NAME_2D_INPUT_UNARY} elementwise_example_unary.cpp)
|
||||
|
||||
# Elementwise transpose example targets 2D inputs
|
||||
set(TARGET_NAME_2D_INPUT_TRANSPOSE tile_example_elementwise_transpose)
|
||||
add_executable(${TARGET_NAME_2D_INPUT_TRANSPOSE} elementwise_example_transpose.cpp)
|
||||
|
||||
# Elementwise example targets 4D inputs
|
||||
set(TARGET_NAME_4D_INPUT tile_example_elementwise_add_4d)
|
||||
add_executable(${TARGET_NAME_4D_INPUT} elementwise_example_add_4d.cpp)
|
||||
214
example/ck_tile/21_elementwise/elementwise_example.cpp
Normal file
214
example/ck_tile/21_elementwise/elementwise_example.cpp
Normal file
@@ -0,0 +1,214 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/host/reference/reference_elementwise.hpp"
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "1024", "m dimension")
|
||||
.insert("n", "1024", "n dimension")
|
||||
.insert("stride", "-1", "stride per row, if -1 then equal to n")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "10", "cold iter")
|
||||
.insert("repeat", "50", "hot iter");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t stride = arg_parser.get_int("stride");
|
||||
|
||||
// If stride is negative (default -1), set it to N, assuming a dense row-major layout.
|
||||
if(stride < 0)
|
||||
stride = N;
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
if(stride < N)
|
||||
{
|
||||
throw std::runtime_error("stride must be >= N");
|
||||
}
|
||||
|
||||
// Define type aliases for clarity.
|
||||
// XDataType: Data type of the input tensors.
|
||||
// ComputeDataType: Data type used for intermediate computations (often float for precision).
|
||||
// YDataType: Data type of the output tensor.
|
||||
// XElementwiseOperation: The specific elementwise operation to perform (e.g., Add, Mul).
|
||||
using XDataType = DataType;
|
||||
using ComputeDataType =
|
||||
float; // Using float for intermediate calculations can improve numerical stability.
|
||||
using YDataType = DataType;
|
||||
using XElementwiseOperation = ck_tile::element_wise::Add;
|
||||
|
||||
// 1. Initialize the input data on the host (CPU).
|
||||
// HostTensor is a utility to manage tensor data on the CPU.
|
||||
// The first argument is the shape (dimensions) of the tensor {M, N}.
|
||||
// The second argument is the strides {stride, 1} for row-major layout.
|
||||
// 'x_host_a' and 'x_host_b' are the two input tensors for the elementwise operation.
|
||||
ck_tile::HostTensor<XDataType> x_host_a({M, N}, {stride, 1});
|
||||
ck_tile::HostTensor<XDataType> x_host_b({M, N}, {stride, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host({M, N}, {stride, 1});
|
||||
ck_tile::HostTensor<YDataType> y_validation({M, N}, {stride, 1});
|
||||
|
||||
std::vector<ck_tile::index_t> shape = {M, N};
|
||||
|
||||
// Fill the host tensors with random data.
|
||||
// FillUniformDistribution populates the tensor with values from a uniform distribution,
|
||||
// within an interval.
|
||||
ck_tile::FillUniformDistribution<XDataType>{0.f, 5.f}(x_host_a);
|
||||
ck_tile::FillUniformDistribution<XDataType>{0.f, 5.f}(x_host_b);
|
||||
|
||||
// 2. Create device memory buffers
|
||||
// DeviceMem allocates memory on the GPU.
|
||||
// The size is determined by the total number of elements and the size of DataType.
|
||||
ck_tile::DeviceMem x_buf_a(x_host_a.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem x_buf_b(x_host_b.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_buf(y_host.get_element_space_size_in_bytes());
|
||||
|
||||
// Copy data from host input tensors to device buffers.
|
||||
x_buf_a.ToDevice(x_host_a.data());
|
||||
x_buf_b.ToDevice(x_host_b.data());
|
||||
|
||||
// 3. Configure the kernel execution parameters.
|
||||
// Dividing the problem into blocktile, blockwarp and warptile
|
||||
// The blocktile is the size of the tile processed by a single work group (also called thread
|
||||
// block). The warptile is the size of the tile processed by a single wavefront (also called
|
||||
// warp). The vector is the size of the tile processed by a single work item (also called
|
||||
// thread). The problem is divided into blocks of size BlockTile. Each block is further divided
|
||||
// into wavefronts of size WarpTile. Each wavefront is composed of 64 work items (on AMD; 32
|
||||
// threads on NVIDIA). Each work item in a wavefront processes one vector's worth of elements.
|
||||
// Note that WarpTile/Vector should be 64 for CDNA (because there are 64 work items per
|
||||
// wavefront). Vector size is set to be 16 / sizeof(ComputeDataType), to maximize vectorization.
|
||||
using BlockTile = ck_tile::sequence<2048>; // How many elements are handled by a block tile (the
|
||||
// tensor is divided into blocks of this size)
|
||||
using BlockWarps = ck_tile::sequence<8>; // How many concurrent wavefronts are in a block (each
|
||||
// wavefront will cover some part of the block tile)
|
||||
|
||||
// WarpTile: Defines the size of the data sub-tile processed by a single wavefront.
|
||||
// This should be consistent with BlockTile and BlockWarps.
|
||||
// If BlockTile is 2048 and BlockWarps is 8, then WarpTile could be 2048/8 = 256.
|
||||
// However, this example uses 64, meaning each wavefront processes 64 elements, and multiple
|
||||
// such wavefront operations might be needed to cover the BlockTile, or the BlockTile is
|
||||
// distributed differently.
|
||||
// The current configuration (BlockTile=2048, BlockWarps=8, WarpTile=64) implies that
|
||||
// each wavefront processes 64 elements, and 8 wavefronts process 8*64 = 512 elements
|
||||
// concurrently. Since 512 is not equal to 2048, it means that warptile(s) will need to iterate
|
||||
// over multiple times over different set of elements to cover the entire BlockTile.
|
||||
using WarpTile = ck_tile::sequence<64>;
|
||||
|
||||
// 4. Create the kernel
|
||||
|
||||
// ElementWiseShape bundles these tiling parameters.
|
||||
// It calculates derived properties like threads per wavefront, repeats, vectorization and total
|
||||
// block size.
|
||||
using Shape = ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, ComputeDataType>;
|
||||
|
||||
// ElementWisePipelineProblem encapsulates all necessary information for the elementwise kernel:
|
||||
// - Data types (input, compute, output).
|
||||
// - Shape traits (tiling configuration).
|
||||
// - The specific elementwise operation (e.g., Add).
|
||||
using Problem = ck_tile::ElementWisePipelineProblem<XDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
Shape,
|
||||
XElementwiseOperation>;
|
||||
|
||||
// ElementWiseKernel refers to the GPU kernel class
|
||||
using Kernel = ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;
|
||||
|
||||
// Compute flattened size
|
||||
ck_tile::index_t total_elements = 1;
|
||||
for(auto d : shape)
|
||||
total_elements *= d;
|
||||
|
||||
// kBlockSize: The number of work items in a GPU workgroup (thread block).
|
||||
// This is often a multiple of the wavefront size, 64 on CDNA.
|
||||
// Here, it's explicitly set to 512. This should be consistent with Shape::kBlockSize.
|
||||
// Shape::kBlockSize would be BlockWarps * warpSize (e.g., 8 * 64 = 512).
|
||||
constexpr ck_tile::index_t kBlockSize =
|
||||
ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{});
|
||||
|
||||
// kBlockPerCu: Hint for how many workgroups can be scheduled per Compute Unit (CU).
|
||||
// This can influence occupancy and performance.
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
// kGridSize: Calculates the total number of workgroups required to process all elements.
|
||||
// Each workgroup is responsible for 'elements_per_block' elements.
|
||||
// To ensure all elements are covered, especially when 'total_elements' is not perfectly
|
||||
// divisible by 'elements_per_block', using ceiling division.
|
||||
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
|
||||
ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block;
|
||||
|
||||
std::cout << "grid size = " << kGridSize << std::endl;
|
||||
std::cout << "Total elements = " << total_elements << std::endl;
|
||||
|
||||
auto input_tensors = ck_tile::make_tuple(static_cast<XDataType*>(x_buf_a.GetDeviceBuffer()),
|
||||
static_cast<XDataType*>(x_buf_b.GetDeviceBuffer()));
|
||||
|
||||
auto input_size = ck_tile::make_tuple(M, N);
|
||||
|
||||
// Check if the kernel configuration is supported
|
||||
if(!Kernel::IsSupportedArgument(input_size))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"The kernel configuration is not supported for the given input size.");
|
||||
}
|
||||
|
||||
// 4. Run the kernel
|
||||
float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
|
||||
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
|
||||
Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
input_size,
|
||||
ck_tile::make_tuple(N, 1), // Input Stride
|
||||
ck_tile::make_tuple(N, 1), // Output Stride
|
||||
input_tensors,
|
||||
static_cast<YDataType*>(y_buf.GetDeviceBuffer())));
|
||||
|
||||
std::cout << "Average time: " << ave_time << " ms" << std::endl;
|
||||
|
||||
// 5. Verify the output
|
||||
bool pass = true;
|
||||
if(do_validation)
|
||||
{
|
||||
y_buf.FromDevice(y_validation.data());
|
||||
auto op = [](const auto& v0, const auto& v1) { return v0 + v1; };
|
||||
|
||||
ck_tile::reference_binary_elementwise<XDataType, XDataType, YDataType, ComputeDataType>(
|
||||
x_host_a, x_host_b, y_host, op);
|
||||
|
||||
pass = ck_tile::check_err(
|
||||
y_validation, y_host, "Elementwise Add Error: Incorrect results!", 0.01, 0.01);
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
return -3;
|
||||
}
|
||||
159
example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp
Normal file
159
example/ck_tile/21_elementwise/elementwise_example_add_4d.cpp
Normal file
@@ -0,0 +1,159 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/host/reference/reference_elementwise.hpp"
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("dim0", "4", "dimension 0")
|
||||
.insert("dim1", "16", "dimension 1")
|
||||
.insert("dim2", "32", "dimension 2")
|
||||
.insert("dim3", "32", "dimension 3")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "10", "cold iter")
|
||||
.insert("repeat", "50", "hot iter");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
// Runs the 4D elementwise Add example (y = a + b) on dense D0 x D1 x D2 x D3
// tensors and optionally validates the GPU result against a host reference.
// Returns true when validation passes (or is skipped), false otherwise.
// Throws std::runtime_error when the kernel rejects the problem shape.
template <typename DataType>
bool run(const ck_tile::ArgParser& arg_parser)
{
    ck_tile::index_t D0 = arg_parser.get_int("dim0");
    ck_tile::index_t D1 = arg_parser.get_int("dim1");
    ck_tile::index_t D2 = arg_parser.get_int("dim2");
    ck_tile::index_t D3 = arg_parser.get_int("dim3");

    std::string data_type = arg_parser.get_str("prec");
    int do_validation = arg_parser.get_int("v");
    int warmup = arg_parser.get_int("warmup");
    int repeat = arg_parser.get_int("repeat");

    using XDataType = DataType;
    using ComputeDataType =
        float; // Using float for intermediate calculations can improve numerical stability.
    using YDataType = DataType;
    using XElementwiseOperation = ck_tile::element_wise::Add;

    // Initialize the input data on the host (CPU).
    std::vector<ck_tile::index_t> problem_shape = {D0, D1, D2, D3};

    // Dense row-major strides: the innermost (last) dimension is contiguous,
    // each outer stride is the product of all inner dimension lengths.
    std::vector<ck_tile::index_t> host_strides(4);
    host_strides[3] = 1;
    host_strides[2] = problem_shape[3];
    host_strides[1] = problem_shape[2] * problem_shape[3];
    host_strides[0] = problem_shape[1] * problem_shape[2] * problem_shape[3];

    ck_tile::HostTensor<XDataType> x_host_a(problem_shape, host_strides);
    ck_tile::HostTensor<XDataType> x_host_b(problem_shape, host_strides);
    ck_tile::HostTensor<YDataType> y_host(problem_shape, host_strides);
    ck_tile::HostTensor<YDataType> y_validation(problem_shape, host_strides);

    // Fill inputs with uniform random data (different ranges per input).
    ck_tile::FillUniformDistribution<XDataType>{0.f, 5.f}(x_host_a);
    ck_tile::FillUniformDistribution<XDataType>{2.f, 10.f}(x_host_b);

    // Allocate device buffers and upload only the inputs.
    ck_tile::DeviceMem x_buf_a(x_host_a.get_element_space_size_in_bytes());
    ck_tile::DeviceMem x_buf_b(x_host_b.get_element_space_size_in_bytes());
    ck_tile::DeviceMem y_buf(y_host.get_element_space_size_in_bytes());

    x_buf_a.ToDevice(x_host_a.data());
    x_buf_b.ToDevice(x_host_b.data());

    // Tiling: one wavefront per block (BlockWarps=1), each block tile covers
    // 256 elements of the flattened problem.
    using BlockTile  = ck_tile::sequence<256>;
    using BlockWarps = ck_tile::sequence<1>;
    using WarpTile   = ck_tile::sequence<256>;

    using Shape = ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, ComputeDataType>;

    using Problem = ck_tile::ElementWisePipelineProblem<XDataType,
                                                        ComputeDataType,
                                                        YDataType,
                                                        Shape,
                                                        XElementwiseOperation>;

    using Kernel = ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;

    // Flattened element count across all four dimensions.
    ck_tile::index_t total_elements = 1;
    for(auto d : problem_shape)
        total_elements *= d;

    // Work items per block: wavefront size (64 on CDNA) times BlockWarps.
    constexpr ck_tile::index_t kBlockSize =
        ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{});

    // Occupancy hint: workgroups per Compute Unit.
    constexpr ck_tile::index_t kBlockPerCu = 2;

    // Ceiling division so a partially-filled trailing block is still launched.
    constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
    ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block;

    std::cout << "grid size = " << kGridSize << std::endl;
    std::cout << "Total elements = " << total_elements << std::endl;

    auto input_tensors = ck_tile::make_tuple(static_cast<XDataType*>(x_buf_a.GetDeviceBuffer()),
                                             static_cast<XDataType*>(x_buf_b.GetDeviceBuffer()));

    auto problem_shape_tuple =
        ck_tile::make_tuple(problem_shape[0], problem_shape[1], problem_shape[2], problem_shape[3]);

    auto strides_tuple =
        ck_tile::make_tuple(host_strides[0], host_strides[1], host_strides[2], host_strides[3]);

    // Check if the kernel configuration is supported
    if(!Kernel::IsSupportedArgument(problem_shape_tuple))
    {
        throw std::runtime_error(
            "The kernel configuration is not supported for the given input size.");
    }

    // Run the kernel. Input and output share the same dense strides here.
    float ave_time = launch_kernel(
        ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
        ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
            Kernel{},
            kGridSize,
            kBlockSize,
            0,
            problem_shape_tuple, // ck_tile::tuple<index_t, index_t, index_t, index_t>
            strides_tuple, // ck_tile::tuple<index_t, index_t, index_t, index_t> for input strides
            strides_tuple, // ck_tile::tuple<index_t, index_t, index_t, index_t> for output strides
            input_tensors,
            static_cast<YDataType*>(y_buf.GetDeviceBuffer())));

    std::cout << "Average time: " << ave_time << " ms" << std::endl;

    // Verify the output against a host reference when requested (-v 1).
    bool pass = true;
    if(do_validation)
    {
        y_buf.FromDevice(y_validation.data());
        auto op = [](const auto& v0, const auto& v1) { return v0 + v1; };

        ck_tile::reference_binary_elementwise<XDataType, XDataType, YDataType, ComputeDataType>(
            x_host_a, x_host_b, y_host, op);

        pass = ck_tile::check_err(
            y_validation, y_host, "Elementwise Add Error: Incorrect results!", 0.01, 0.01);
    }

    return pass;
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
return -3;
|
||||
}
|
||||
156
example/ck_tile/21_elementwise/elementwise_example_transpose.cpp
Normal file
156
example/ck_tile/21_elementwise/elementwise_example_transpose.cpp
Normal file
@@ -0,0 +1,156 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/reference/reference_transpose.hpp"
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "1024", "m dimension of input")
|
||||
.insert("n", "1024", "n dimension of input")
|
||||
.insert("stride_in", "-1", "stride for input M dim, if -1 then equal to n")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "10", "cold iter")
|
||||
.insert("repeat", "50", "hot iter");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t stride_in = arg_parser.get_int("stride_in");
|
||||
|
||||
if(stride_in < 0)
|
||||
stride_in = N; // Dense input: stride for M dim is N
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
if(stride_in < N)
|
||||
{
|
||||
throw std::runtime_error("stride_in must be >= N");
|
||||
}
|
||||
|
||||
using XDataType = DataType;
|
||||
using ComputeDataType = float;
|
||||
using YDataType = DataType;
|
||||
// Use PassThrough operation for transposition (data is moved, not changed)
|
||||
using XElementwiseOperation = ck_tile::element_wise::PassThrough;
|
||||
|
||||
// 1. Initialize the input data on the host (CPU).
|
||||
// Input x_host_a: M x N
|
||||
// Output y_host: N x M (transposed)
|
||||
ck_tile::HostTensor<XDataType> x_host_a({M, N}, {stride_in, 1});
|
||||
// Output tensor y_host will have dimensions N x M.
|
||||
// Assuming dense output, its stride for the N dimension will be M.
|
||||
ck_tile::index_t stride_out_dim0 = M;
|
||||
ck_tile::HostTensor<YDataType> y_host({N, M}, {stride_out_dim0, 1});
|
||||
ck_tile::HostTensor<YDataType> y_validation({N, M}, {stride_out_dim0, 1});
|
||||
|
||||
// The logical shape for the element-wise operation kernel is based on the input tensor's
|
||||
// elements.
|
||||
std::vector<ck_tile::index_t> op_shape_vec = {M, N};
|
||||
auto op_lengths = ck_tile::make_tuple(M, N); // Lens for the kernel
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{0.f, 5.f}(x_host_a);
|
||||
|
||||
// 2. Create device memory buffers
|
||||
ck_tile::DeviceMem x_buf_a(x_host_a.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_buf(y_host.get_element_space_size_in_bytes()); // y_host is N x M
|
||||
|
||||
x_buf_a.ToDevice(x_host_a.data());
|
||||
|
||||
// 3. Configure the kernel execution parameters.
|
||||
using BlockTile = ck_tile::sequence<1024>;
|
||||
using BlockWarps = ck_tile::sequence<8>;
|
||||
using WarpTile = ck_tile::sequence<64>;
|
||||
|
||||
using Shape = ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, ComputeDataType>;
|
||||
|
||||
// Problem definition for a single input tensor
|
||||
using Problem = ck_tile::ElementWisePipelineProblem<XDataType,
|
||||
ComputeDataType,
|
||||
YDataType,
|
||||
Shape,
|
||||
XElementwiseOperation>;
|
||||
|
||||
using Kernel = ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;
|
||||
|
||||
ck_tile::index_t total_elements = M * N;
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize = 64 * BlockWarps::at(ck_tile::number<0>{});
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
|
||||
ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block;
|
||||
|
||||
std::cout << "Input M=" << M << ", N=" << N << ", StrideIn=" << stride_in << std::endl;
|
||||
std::cout << "Output N=" << N << ", M=" << M << ", StrideOut=" << stride_out_dim0 << std::endl;
|
||||
std::cout << "Grid size = " << kGridSize << ", BlockSize = " << kBlockSize << std::endl;
|
||||
std::cout << "Total elements = " << total_elements << std::endl;
|
||||
|
||||
// Input tensors tuple (single input)
|
||||
auto input_tensors = ck_tile::make_tuple(static_cast<XDataType*>(x_buf_a.GetDeviceBuffer()));
|
||||
// Input strides tuple (tuple of tuples, one for each input)
|
||||
auto input_strides = ck_tile::make_tuple(stride_in, 1);
|
||||
// Output strides (for N x M tensor, dense)
|
||||
auto output_strides = ck_tile::make_tuple(1, stride_out_dim0);
|
||||
|
||||
// Check if the kernel configuration is supported
|
||||
if(!Kernel::IsSupportedArgument(op_lengths))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"The kernel configuration is not supported for the given input size.");
|
||||
}
|
||||
|
||||
// 4. Run the kernel
|
||||
float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
|
||||
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
|
||||
Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0, // Shared memory
|
||||
op_lengths, // Logical dimensions for the operation (M, N)
|
||||
input_strides, // Strides for input tensor(s)
|
||||
output_strides, // Strides for output tensor (N, M)
|
||||
input_tensors,
|
||||
static_cast<YDataType*>(y_buf.GetDeviceBuffer())));
|
||||
|
||||
std::cout << "Average time: " << ave_time << " ms" << std::endl;
|
||||
|
||||
// 5. Verify the output
|
||||
bool pass = true;
|
||||
if(do_validation)
|
||||
{
|
||||
y_buf.FromDevice(y_validation.data()); // Copy result from device to y_validation
|
||||
ck_tile::reference_transpose_elementwise<XDataType, YDataType>(
|
||||
x_host_a, y_host); // Compute reference on host
|
||||
pass = ck_tile::check_err(
|
||||
y_validation, y_host, "Transpose Error: Incorrect results!", 0.01, 0.01);
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
std::cerr << "Unsupported data type: " << data_type << std::endl;
|
||||
return -3;
|
||||
}
|
||||
147
example/ck_tile/21_elementwise/elementwise_example_unary.cpp
Normal file
147
example/ck_tile/21_elementwise/elementwise_example_unary.cpp
Normal file
@@ -0,0 +1,147 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/ops/elementwise.hpp"
|
||||
#include "ck_tile/host/reference/reference_elementwise.hpp"
|
||||
|
||||
auto create_args(int argc, char* argv[])
|
||||
{
|
||||
ck_tile::ArgParser arg_parser;
|
||||
arg_parser.insert("m", "1024", "m dimension")
|
||||
.insert("n", "1024", "n dimension")
|
||||
.insert("stride", "-1", "stride per row, if -1 then equal to n")
|
||||
.insert("v", "1", "cpu validation or not")
|
||||
.insert("prec", "fp16", "precision")
|
||||
.insert("warmup", "10", "cold iter")
|
||||
.insert("repeat", "50", "hot iter");
|
||||
|
||||
bool result = arg_parser.parse(argc, argv);
|
||||
return std::make_tuple(result, arg_parser);
|
||||
}
|
||||
|
||||
template <typename DataType>
|
||||
bool run(const ck_tile::ArgParser& arg_parser)
|
||||
{
|
||||
ck_tile::index_t M = arg_parser.get_int("m");
|
||||
ck_tile::index_t N = arg_parser.get_int("n");
|
||||
ck_tile::index_t stride = arg_parser.get_int("stride");
|
||||
if(stride < 0)
|
||||
stride = N;
|
||||
std::string data_type = arg_parser.get_str("prec");
|
||||
int do_validation = arg_parser.get_int("v");
|
||||
int warmup = arg_parser.get_int("warmup");
|
||||
int repeat = arg_parser.get_int("repeat");
|
||||
|
||||
assert(stride >= N);
|
||||
|
||||
using XDataType = DataType;
|
||||
using YDataType = DataType;
|
||||
using ComputeDataType = float;
|
||||
using XElementwiseOperation = ck_tile::element_wise::UnarySquare;
|
||||
|
||||
// 1. Initialize the input data on the host
|
||||
ck_tile::HostTensor<XDataType> x_host_a({M, N}, {stride, 1});
|
||||
ck_tile::HostTensor<YDataType> y_host({M, N}, {stride, 1});
|
||||
ck_tile::HostTensor<YDataType> y_validation({M, N}, {stride, 1});
|
||||
|
||||
std::vector<ck_tile::index_t> shape = {M, N};
|
||||
|
||||
ck_tile::FillUniformDistribution<XDataType>{0.f, 5.f}(x_host_a);
|
||||
|
||||
// 2. Create device memory buffers and copy input data from host to device
|
||||
ck_tile::DeviceMem x_buf_a(x_host_a.get_element_space_size_in_bytes());
|
||||
ck_tile::DeviceMem y_buf(y_host.get_element_space_size_in_bytes());
|
||||
x_buf_a.ToDevice(x_host_a.data());
|
||||
|
||||
// 3. Create the kernel
|
||||
|
||||
// Dividing the problem into blocktile, warptile, and vector
|
||||
using BlockTile = ck_tile::sequence<2048>; // Size of the block tile (Entire problem is divided
|
||||
// into blocks of this size)
|
||||
using BlockWarps = ck_tile::sequence<8>; // How many concurrent warps are in a block (Each warp
|
||||
// will cover some part of blockTile)
|
||||
using WarpTile = ck_tile::sequence<64>; // How many elements are covered by a warp
|
||||
|
||||
using Shape = ck_tile::ElementWiseShape<BlockWarps, BlockTile, WarpTile, ComputeDataType>;
|
||||
using Problem = ck_tile::ElementWisePipelineProblem<XDataType,
|
||||
XDataType, // ComputeDataType is same as
|
||||
// XDataType in the unary case
|
||||
YDataType,
|
||||
Shape,
|
||||
XElementwiseOperation>;
|
||||
|
||||
using Kernel = ck_tile::ElementWiseKernel<Problem, ck_tile::ElementWiseDefaultPolicy>;
|
||||
|
||||
// Compute flattened size
|
||||
ck_tile::index_t total_elements = 1;
|
||||
for(auto d : shape)
|
||||
total_elements *= d;
|
||||
|
||||
constexpr ck_tile::index_t kBlockSize =
|
||||
ck_tile::get_warp_size() * BlockWarps::at(ck_tile::number<0>{});
|
||||
constexpr ck_tile::index_t kBlockPerCu = 1;
|
||||
|
||||
constexpr ck_tile::index_t elements_per_block = BlockTile::at(ck_tile::number<0>{});
|
||||
ck_tile::index_t kGridSize = (total_elements + elements_per_block - 1) / elements_per_block;
|
||||
|
||||
std::cout << "grid size = " << kGridSize << std::endl;
|
||||
std::cout << "Total elements = " << total_elements << std::endl;
|
||||
|
||||
auto input_tensors = ck_tile::make_tuple(static_cast<XDataType*>(x_buf_a.GetDeviceBuffer()));
|
||||
auto input_size = ck_tile::make_tuple(M, N);
|
||||
|
||||
// Check if the kernel configuration is supported
|
||||
if(!Kernel::IsSupportedArgument(input_size))
|
||||
{
|
||||
throw std::runtime_error(
|
||||
"The kernel configuration is not supported for the given input size.");
|
||||
}
|
||||
|
||||
// 4. Run the kernel
|
||||
float ave_time = launch_kernel(ck_tile::stream_config{nullptr, true, 0, warmup, repeat},
|
||||
ck_tile::make_kernel<kBlockSize, kBlockPerCu>(
|
||||
Kernel{},
|
||||
kGridSize,
|
||||
kBlockSize,
|
||||
0,
|
||||
input_size,
|
||||
ck_tile::make_tuple(N, 1), // Input Stride
|
||||
ck_tile::make_tuple(N, 1), // Output Stride
|
||||
input_tensors,
|
||||
static_cast<YDataType*>(y_buf.GetDeviceBuffer())));
|
||||
|
||||
std::cout << "Average time: " << ave_time << " ms" << std::endl;
|
||||
|
||||
// 5. Verify the output
|
||||
bool pass = true;
|
||||
if(do_validation)
|
||||
{
|
||||
y_buf.FromDevice(y_validation.data());
|
||||
|
||||
auto op = [](const auto& v0) { return v0 * v0; };
|
||||
|
||||
ck_tile::reference_unary_elementwise<XDataType, YDataType, YDataType>(x_host_a, y_host, op);
|
||||
|
||||
pass = ck_tile::check_err(
|
||||
y_validation, y_host, "Elementwise Add Error: Incorrect results!", 0.01, 0.01);
|
||||
}
|
||||
|
||||
return pass;
|
||||
}
|
||||
|
||||
int main(int argc, char* argv[])
|
||||
{
|
||||
auto [result, arg_parser] = create_args(argc, argv);
|
||||
if(!result)
|
||||
return -1;
|
||||
|
||||
const std::string data_type = arg_parser.get_str("prec");
|
||||
if(data_type == "fp16")
|
||||
{
|
||||
return run<ck_tile::half_t>(arg_parser) ? 0 : -2;
|
||||
}
|
||||
|
||||
return -3;
|
||||
}
|
||||
@@ -20,6 +20,7 @@ add_subdirectory(17_grouped_gemm)
|
||||
add_subdirectory(18_flatmm)
|
||||
add_subdirectory(19_gemm_multi_d)
|
||||
add_subdirectory(20_grouped_convolution)
|
||||
add_subdirectory(21_elementwise)
|
||||
add_subdirectory(35_batched_transpose)
|
||||
add_subdirectory(37_transpose)
|
||||
add_subdirectory(38_block_scale_gemm)
|
||||
|
||||
@@ -264,10 +264,14 @@ struct tuple : impl::tuple_base<make_index_sequence<sizeof...(T)>, T...>
|
||||
|
||||
#define TP_COM_() static_assert(I < size(), "wrong! out of range")
|
||||
// clang-format off
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const { TP_COM_(); return impl::getv<I>(*this); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) const { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() { TP_COM_(); return impl::getv<I>(*this); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const & { TP_COM_(); return impl::getv<I>(*this); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) const & { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() & { TP_COM_(); return impl::getv<I>(*this); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) & { TP_COM_(); return get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() && { TP_COM_(); return impl::getv<I>(std::move(*this)); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) && { TP_COM_(); return std::move(*this).template get<I>(); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get() const && { TP_COM_(); return impl::getv<I>(std::move(*this)); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) get(number<I>) const &&{ TP_COM_(); return std::move(*this).template get<I>(); }
|
||||
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at() const { TP_COM_(); return impl::getv<I>(*this); }
|
||||
template<index_t I> CK_TILE_HOST_DEVICE constexpr decltype(auto) at(number<I>) const { TP_COM_(); return get<I>(); }
|
||||
@@ -470,6 +474,12 @@ transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, sequence<Is...>)
|
||||
return make_tuple(f(x.at(number<Is>{}), y.at(number<Is>{}), z.at(number<Is>{}))...);
|
||||
}
|
||||
|
||||
template <typename F, typename Tuple, index_t... Is>
|
||||
constexpr decltype(auto) apply_impl(F&& f, Tuple&& t, sequence<Is...>)
|
||||
{
|
||||
return std::forward<F>(f)(std::forward<Tuple>(t).get(number<Is>{})...);
|
||||
}
|
||||
|
||||
} // namespace detail
|
||||
|
||||
template <typename F, typename X>
|
||||
@@ -493,6 +503,13 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y,
|
||||
f, x, y, z, typename arithmetic_sequence_gen<0, X::size(), 1>::type{});
|
||||
}
|
||||
|
||||
template <typename F, typename Tuple>
|
||||
constexpr decltype(auto) apply(F&& f, Tuple&& t)
|
||||
{
|
||||
constexpr index_t N = std::decay_t<Tuple>::size();
|
||||
return detail::apply_impl(std::forward<F>(f), std::forward<Tuple>(t), make_index_sequence<N>{});
|
||||
}
|
||||
|
||||
namespace detail {
|
||||
|
||||
template <typename F, typename X, index_t... Is>
|
||||
|
||||
@@ -38,6 +38,7 @@
|
||||
#include "ck_tile/host/reference/reference_rowwise_quantization2d.hpp"
|
||||
#include "ck_tile/host/reference/reference_softmax.hpp"
|
||||
#include "ck_tile/host/reference/reference_topk.hpp"
|
||||
#include "ck_tile/host/reference/reference_transpose.hpp"
|
||||
#include "ck_tile/host/rotating_buffers.hpp"
|
||||
#include "ck_tile/host/stream_config.hpp"
|
||||
#include "ck_tile/host/stream_utils.hpp"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
33
include/ck_tile/host/reference/reference_transpose.hpp
Normal file
33
include/ck_tile/host/reference/reference_transpose.hpp
Normal file
@@ -0,0 +1,33 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host/host_tensor.hpp"
|
||||
#include <thread>
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType, typename BDataType>
|
||||
void reference_transpose_elementwise(const HostTensor<ADataType>& a, HostTensor<BDataType>& b)
|
||||
{
|
||||
ck_tile::index_t M = static_cast<ck_tile::index_t>(a.mDesc.get_lengths()[0]);
|
||||
ck_tile::index_t N = static_cast<ck_tile::index_t>(a.mDesc.get_lengths()[1]);
|
||||
|
||||
// Ensure the b tensor is sized correctly for N x M
|
||||
if(static_cast<ck_tile::index_t>(b.mDesc.get_lengths()[0]) != N ||
|
||||
static_cast<ck_tile::index_t>(b.mDesc.get_lengths()[1]) != M)
|
||||
{
|
||||
throw std::runtime_error("Output tensor b has incorrect dimensions for transpose.");
|
||||
}
|
||||
|
||||
auto f = [&](auto i, auto j) {
|
||||
auto v_a = a(i, j);
|
||||
b(j, i) = ck_tile::type_convert<BDataType>(v_a);
|
||||
};
|
||||
|
||||
make_ParallelTensorFunctor(f, M, N)(std::thread::hardware_concurrency());
|
||||
}
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -3,6 +3,11 @@
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/ops/elementwise/binary_elementwise_operation.hpp"
|
||||
#include "ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp"
|
||||
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
namespace element_wise {
|
||||
|
||||
struct Add
|
||||
{
|
||||
template <typename Y, typename X0, typename X1>
|
||||
__host__ __device__ constexpr void operator()(Y& y, const X0& x0, const X1& x1) const;
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<float>(float& y, const float& x0, const float& x1) const
|
||||
{
|
||||
y = x0 + x1;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<double>(double& y, const double& x0, const double& x1) const
|
||||
{
|
||||
y = x0 + x1;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<float>(float& y, const float& x0, const half_t& x1) const
|
||||
{
|
||||
y = x0 + type_convert<half_t>(x1);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t>(half_t& y, const float& x0, const float& x1) const
|
||||
{
|
||||
y = type_convert<half_t>(x0 + x1);
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t>(half_t& y, const float& x0, const half_t& x1) const
|
||||
{
|
||||
y = type_convert<half_t>(x0) + x1;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<half_t>(half_t& y, const half_t& x0, const half_t& x1) const
|
||||
{
|
||||
y = x0 + x1;
|
||||
};
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<float>(float& y, const float& x0, const bf16_t& x1) const
|
||||
{
|
||||
const float x1_tmp = type_convert<float>(x1);
|
||||
y = x0 + x1_tmp;
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<bf16_t>(bf16_t& y, const bf16_t& x0, const bf16_t& x1) const
|
||||
{
|
||||
const float x1_tmp = type_convert<float>(x0);
|
||||
const float x2_tmp = type_convert<float>(x1);
|
||||
const float y_tmp = x1_tmp + x2_tmp;
|
||||
y = type_convert<bf16_t>(y_tmp);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<bf16_t>(bf16_t& y, const float& x0, const bf16_t& x1) const
|
||||
{
|
||||
const float x2_tmp = type_convert<float>(x1);
|
||||
const float y_tmp = x0 + x2_tmp;
|
||||
y = type_convert<bf16_t>(y_tmp);
|
||||
}
|
||||
|
||||
template <>
|
||||
__host__ __device__ constexpr void
|
||||
operator()<int8_t>(int8_t& y, const int8_t& x0, const int8_t& x1) const
|
||||
{
|
||||
y = x0 + x1;
|
||||
};
|
||||
};
|
||||
|
||||
} // namespace element_wise
|
||||
} // namespace ck_tile
|
||||
123
include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp
Normal file
123
include/ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp
Normal file
@@ -0,0 +1,123 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp"
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_, typename Policy_>
|
||||
struct ElementWiseKernel
|
||||
{
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
||||
using ElementWiseOperation = ck_tile::remove_cvref_t<typename Problem::ElementWiseOperation>;
|
||||
|
||||
template <typename... XDataType, typename Dims>
|
||||
CK_TILE_DEVICE void operator()(Dims lens,
|
||||
Dims input_strides,
|
||||
Dims output_strides,
|
||||
const tuple<XDataType...>& input_tensors,
|
||||
YDataType* p_y) const
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
|
||||
// Setup block-level coordinates and transforms
|
||||
const index_t iM = get_block_id() * S::kBlockM;
|
||||
const auto merge_transform = make_merge_transform(lens);
|
||||
|
||||
// Load all input tiles into registers.
|
||||
// The lambda structure here is intended to minimize the lifetime
|
||||
// of intermediate objects (views, windows) used for loading.
|
||||
const auto x_tiles = ck_tile::generate_tuple(
|
||||
[&](auto i) {
|
||||
const auto tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
input_tensors.get(i), lens, input_strides, number<S::kVectorM>{}, number<1>{});
|
||||
|
||||
const auto transformed_tensor = pad_tensor_view(
|
||||
transform_tensor_view(tensor_view,
|
||||
ck_tile::make_tuple(merge_transform),
|
||||
ck_tile::make_tuple(make_index_sequence<Dims::size()>{}),
|
||||
ck_tile::make_tuple(sequence<0>{})),
|
||||
ck_tile::make_tuple(number<S::kBlockM>{}),
|
||||
sequence<Problem::kPad>{});
|
||||
|
||||
const auto x_window =
|
||||
make_tile_window(transformed_tensor,
|
||||
ck_tile::make_tuple(number<S::kBlockM>{}),
|
||||
{iM},
|
||||
Policy::template MakeXBlockTileDistribution<Problem>());
|
||||
|
||||
return load_tile(x_window);
|
||||
},
|
||||
number<sizeof...(XDataType)>{});
|
||||
|
||||
// Setup output tile in registers.
|
||||
const auto& x_tile0 = x_tiles.get(number<0>{});
|
||||
auto y_tile = make_static_distributed_tensor<YDataType>(x_tile0.get_tile_distribution());
|
||||
|
||||
// Perform element-wise computation.
|
||||
const auto spans = x_tile0.get_distributed_spans();
|
||||
sweep_tile_span(spans[number<0>{}], [&](auto idx) {
|
||||
const auto tile_idx = make_tuple(idx);
|
||||
apply(
|
||||
[&](auto&&... tiles) {
|
||||
ElementWiseOperation{}(y_tile(tile_idx),
|
||||
type_convert<ComputeDataType>(tiles[tile_idx])...);
|
||||
},
|
||||
x_tiles);
|
||||
});
|
||||
|
||||
// Setup output window and store the result tile.
|
||||
const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_y, lens, output_strides, number<S::kVectorM>{});
|
||||
|
||||
const auto transformed_y_m_n = pad_tensor_view(
|
||||
transform_tensor_view(y_m_n,
|
||||
ck_tile::make_tuple(merge_transform),
|
||||
ck_tile::make_tuple(make_index_sequence<Dims::size()>{}),
|
||||
ck_tile::make_tuple(sequence<0>{})),
|
||||
ck_tile::make_tuple(number<S::kBlockM>{}),
|
||||
sequence<Problem::kPad>{});
|
||||
|
||||
auto y_window = make_tile_window(transformed_y_m_n,
|
||||
make_tuple(number<S::kBlockM>{}),
|
||||
{iM},
|
||||
y_tile.get_tile_distribution());
|
||||
|
||||
store_tile(y_window, cast_tile<YDataType>(y_tile));
|
||||
}
|
||||
|
||||
template <typename... Ints>
|
||||
CK_TILE_HOST static bool IsSupportedArgument(const ck_tile::tuple<Ints...>& input_sizes)
|
||||
{
|
||||
int total_elements = 1;
|
||||
const auto kVectorM = Problem_::BlockShape::kVectorM;
|
||||
|
||||
apply([&](auto&&... args) { ((total_elements *= args), ...); }, input_sizes);
|
||||
|
||||
if((total_elements % kVectorM) != 0)
|
||||
{
|
||||
if(ck_tile::EnvIsEnabled(CK_TILE_ENV(CK_TILE_LOGGING)))
|
||||
{
|
||||
CK_TILE_ERROR("Conditions not met: total number of input elements (",
|
||||
total_elements,
|
||||
") should be multiple of the vectorization size (",
|
||||
kVectorM,
|
||||
")");
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,29 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
struct ElementWiseDefaultPolicy
|
||||
{
|
||||
template <typename Problem>
|
||||
CK_TILE_DEVICE static constexpr auto MakeXBlockTileDistribution()
|
||||
{
|
||||
using S = typename Problem::BlockShape;
|
||||
return make_static_tile_distribution(
|
||||
tile_distribution_encoding<sequence<>, // Replicate
|
||||
tuple<sequence<S::kRepeatM,
|
||||
S::kWarpPerBlockM,
|
||||
S::kThreadPerWarpM,
|
||||
S::kVectorM>>, // Hierarchical
|
||||
tuple<sequence<1>, sequence<1>>, // Parallel
|
||||
tuple<sequence<1>, sequence<2>>, // Parallel
|
||||
sequence<1, 1>, // Yield
|
||||
sequence<0, 3>>{} // Yield
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,26 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename XDataType_,
|
||||
typename ComputeDataType_,
|
||||
typename YDataType_,
|
||||
typename BlockShape_,
|
||||
typename ElementWiseOperation_,
|
||||
bool kPad_ = true>
|
||||
struct ElementWisePipelineProblem
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YDataType = remove_cvref_t<YDataType_>;
|
||||
using BlockShape = remove_cvref_t<BlockShape_>;
|
||||
using ElementWiseOperation = remove_cvref_t<ElementWiseOperation_>;
|
||||
static constexpr bool kPad = kPad_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -0,0 +1,29 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/utility/type_traits.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename BlockWarps, typename BlockTile, typename WarpTile, typename ComputeDataType>
|
||||
struct ElementWiseShape
|
||||
{
|
||||
static constexpr index_t kBlockM = BlockTile::at(number<0>{});
|
||||
|
||||
static constexpr index_t kWarpM = WarpTile::at(number<0>{});
|
||||
|
||||
static constexpr index_t kVectorM = 16 / sizeof(ComputeDataType);
|
||||
|
||||
static constexpr index_t kWarpPerBlockM = BlockWarps::at(number<0>{});
|
||||
|
||||
static constexpr index_t kThreadPerWarpM = kWarpM / kVectorM;
|
||||
|
||||
static constexpr index_t kRepeatM = kBlockM / (kWarpPerBlockM * kWarpM);
|
||||
|
||||
static constexpr index_t kBlockSize =
|
||||
ck_tile::get_warp_size() * reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
@@ -1,5 +1,5 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
|
||||
@@ -5,6 +5,8 @@ add_subdirectory(batched_gemm)
|
||||
add_subdirectory(grouped_gemm)
|
||||
add_subdirectory(gemm_multi_d)
|
||||
add_subdirectory(data_type)
|
||||
add_subdirectory(container)
|
||||
add_subdirectory(elementwise)
|
||||
# Not including these tests as there is a bug on gfx90a and gfx942
|
||||
# resulting in "GPU core dump"
|
||||
#add_subdirectory(moe_smoothquant)
|
||||
|
||||
6
test/ck_tile/container/CMakeLists.txt
Normal file
6
test/ck_tile/container/CMakeLists.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_gtest_executable(test_ck_tile_tuple_apply test_tuple_apply.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_ck_tile_tuple_apply PRIVATE utility)
|
||||
endif()
|
||||
endif()
|
||||
223
test/ck_tile/container/test_tuple_apply.cpp
Normal file
223
test/ck_tile/container/test_tuple_apply.cpp
Normal file
@@ -0,0 +1,223 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
using namespace ck_tile;
|
||||
|
||||
class TestCkTileTupleApply : public ::testing::Test
|
||||
{
|
||||
public:
|
||||
// Test functors for different scenarios
|
||||
struct AddFunction
|
||||
{
|
||||
template <typename... Args>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(Args... args) const
|
||||
{
|
||||
return (args + ...);
|
||||
}
|
||||
};
|
||||
|
||||
struct MultiplyFunction
|
||||
{
|
||||
template <typename... Args>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(Args... args) const
|
||||
{
|
||||
return (args * ...);
|
||||
}
|
||||
};
|
||||
|
||||
struct MaxFunction
|
||||
{
|
||||
template <typename T>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(T a) const
|
||||
{
|
||||
return a;
|
||||
}
|
||||
|
||||
template <typename T, typename... Args>
|
||||
CK_TILE_HOST_DEVICE constexpr T operator()(T a, Args... args) const
|
||||
{
|
||||
auto rest_max = operator()(args...);
|
||||
return a > rest_max ? a : rest_max;
|
||||
}
|
||||
};
|
||||
|
||||
struct ReturnTupleFunction
|
||||
{
|
||||
template <typename... Args>
|
||||
CK_TILE_HOST_DEVICE constexpr auto operator()(Args... args) const
|
||||
{
|
||||
return make_tuple(args..., sizeof...(args));
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
TEST_F(TestCkTileTupleApply, BasicArithmetic)
|
||||
{
|
||||
// Test with simple arithmetic operations
|
||||
auto t1 = make_tuple(1, 2, 3);
|
||||
auto result1 = apply(AddFunction{}, t1);
|
||||
EXPECT_EQ(result1, 6);
|
||||
|
||||
auto t2 = make_tuple(2, 3, 4, 5);
|
||||
auto result2 = apply(MultiplyFunction{}, t2);
|
||||
EXPECT_EQ(result2, 120);
|
||||
}
|
||||
|
||||
TEST_F(TestCkTileTupleApply, SingleElement)
|
||||
{
|
||||
// Test with single element tuple
|
||||
auto t1 = make_tuple(42);
|
||||
auto result1 = apply(AddFunction{}, t1);
|
||||
EXPECT_EQ(result1, 42);
|
||||
|
||||
auto result2 = apply(MultiplyFunction{}, t1);
|
||||
EXPECT_EQ(result2, 42);
|
||||
}
|
||||
|
||||
TEST_F(TestCkTileTupleApply, EmptyTuple)
|
||||
{
|
||||
// Test with empty tuple
|
||||
auto t = tuple<>{};
|
||||
auto result = apply([]() { return 100; }, t);
|
||||
EXPECT_EQ(result, 100);
|
||||
}
|
||||
|
||||
TEST_F(TestCkTileTupleApply, DifferentTypes)
|
||||
{
|
||||
// Test with different data types
|
||||
auto t1 = make_tuple(1, 2.5f, 3.0);
|
||||
auto result1 = apply(AddFunction{}, t1);
|
||||
EXPECT_FLOAT_EQ(result1, 6.5f);
|
||||
|
||||
// Test with mixed integer and floating point
|
||||
auto t2 = make_tuple(10, 0.5f);
|
||||
auto result2 = apply(MultiplyFunction{}, t2);
|
||||
EXPECT_FLOAT_EQ(result2, 5.0f);
|
||||
}
|
||||
|
||||
TEST_F(TestCkTileTupleApply, ReturnTuple)
|
||||
{
|
||||
// Test function that returns a tuple
|
||||
auto t = make_tuple(1, 2, 3);
|
||||
auto result = apply(ReturnTupleFunction{}, t);
|
||||
|
||||
EXPECT_EQ(result.get<0>(), 1);
|
||||
EXPECT_EQ(result.get<1>(), 2);
|
||||
EXPECT_EQ(result.get<2>(), 3);
|
||||
EXPECT_EQ(result.get<3>(), 3); // size
|
||||
}
|
||||
|
||||
TEST_F(TestCkTileTupleApply, LambdaFunction)
|
||||
{
|
||||
// Test with lambda functions
|
||||
auto t1 = make_tuple(5, 10, 15);
|
||||
auto result1 = apply([](auto a, auto b, auto c) { return a + b + c; }, t1);
|
||||
EXPECT_EQ(result1, 30);
|
||||
|
||||
// Test lambda with capture
|
||||
int multiplier = 2;
|
||||
auto result2 =
|
||||
apply([multiplier](auto a, auto b) { return (a + b) * multiplier; }, make_tuple(3, 7));
|
||||
EXPECT_EQ(result2, 20);
|
||||
}
|
||||
|
||||
TEST_F(TestCkTileTupleApply, ConstexprContext)
|
||||
{
|
||||
// Test in constexpr context
|
||||
constexpr auto t = make_tuple(2, 3, 4);
|
||||
constexpr auto result = apply(MultiplyFunction{}, t);
|
||||
static_assert(result == 24, "Constexpr apply should work");
|
||||
EXPECT_EQ(result, 24);
|
||||
}
|
||||
|
||||
TEST_F(TestCkTileTupleApply, ReferenceTypes)
|
||||
{
|
||||
// Test with reference types using tie
|
||||
int a = 1, b = 2, c = 3;
|
||||
auto ref_tuple = tie(a, b, c);
|
||||
|
||||
// Function that modifies references
|
||||
apply(
|
||||
[](auto& x, auto& y, auto& z) {
|
||||
x += 10;
|
||||
y += 20;
|
||||
z += 30;
|
||||
},
|
||||
ref_tuple);
|
||||
|
||||
EXPECT_EQ(a, 11);
|
||||
EXPECT_EQ(b, 22);
|
||||
EXPECT_EQ(c, 33);
|
||||
}
|
||||
|
||||
TEST_F(TestCkTileTupleApply, MoveSemantics)
|
||||
{
|
||||
// Test with move semantics
|
||||
auto t = make_tuple(1, 2, 3);
|
||||
auto result = apply(AddFunction{}, std::move(t));
|
||||
EXPECT_EQ(result, 6);
|
||||
}
|
||||
|
||||
TEST_F(TestCkTileTupleApply, NumberTypes)
|
||||
{
|
||||
// Test with ck_tile::number types
|
||||
auto t = make_tuple(number<1>{}, number<2>{}, number<3>{});
|
||||
auto result = apply([](auto a, auto b, auto c) { return a + b + c; }, t);
|
||||
EXPECT_EQ(result, 6);
|
||||
}
|
||||
|
||||
TEST_F(TestCkTileTupleApply, ElementwiseOperation)
|
||||
{
|
||||
// Test simulating elementwise operations
|
||||
auto input1 = make_tuple(1.0f, 2.0f, 3.0f);
|
||||
auto input2 = make_tuple(4.0f, 5.0f, 6.0f);
|
||||
|
||||
auto add_elementwise = [](const auto& a, const auto& b) {
|
||||
return apply(
|
||||
[&b](auto... args_a) {
|
||||
return apply(
|
||||
[args_a...](auto... args_b) { return make_tuple((args_a + args_b)...); }, b);
|
||||
},
|
||||
a);
|
||||
};
|
||||
|
||||
auto result = add_elementwise(input1, input2);
|
||||
|
||||
EXPECT_FLOAT_EQ(result.get<0>(), 5.0f);
|
||||
EXPECT_FLOAT_EQ(result.get<1>(), 7.0f);
|
||||
EXPECT_FLOAT_EQ(result.get<2>(), 9.0f);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
class TestCkTileTupleApplySize : public TestCkTileTupleApply
|
||||
{
|
||||
protected:
|
||||
static constexpr int Size = T::value;
|
||||
};
|
||||
|
||||
using TupleSizes = ::testing::Types<std::integral_constant<int, 1>,
|
||||
std::integral_constant<int, 2>,
|
||||
std::integral_constant<int, 3>,
|
||||
std::integral_constant<int, 4>,
|
||||
std::integral_constant<int, 8>,
|
||||
std::integral_constant<int, 16>>;
|
||||
|
||||
TYPED_TEST_SUITE(TestCkTileTupleApplySize, TupleSizes);
|
||||
|
||||
TYPED_TEST(TestCkTileTupleApplySize, GeneratedTupleSum)
|
||||
{
|
||||
constexpr int N = TypeParam::value;
|
||||
|
||||
// Generate tuple with values 1, 2, 3, ..., N
|
||||
constexpr auto t = generate_tuple([](auto i) { return i.value + 1; }, number<N>{});
|
||||
|
||||
// Sum all elements
|
||||
constexpr auto result = apply(TestCkTileTupleApply::AddFunction{}, t);
|
||||
|
||||
// Expected sum: 1 + 2 + ... + N = N*(N+1)/2
|
||||
constexpr int expected = N * (N + 1) / 2;
|
||||
static_assert(result == expected);
|
||||
}
|
||||
6
test/ck_tile/elementwise/CMakeLists.txt
Normal file
6
test/ck_tile/elementwise/CMakeLists.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
if(GPU_TARGETS MATCHES "gfx9")
|
||||
add_gtest_executable(test_ck_tile_elementwise_1d test_elementwise_1d.cpp)
|
||||
if(result EQUAL 0)
|
||||
target_link_libraries(test_ck_tile_elementwise_1d PRIVATE utility)
|
||||
endif()
|
||||
endif()
|
||||
216
test/ck_tile/elementwise/test_elementwise_1d.cpp
Normal file
216
test/ck_tile/elementwise/test_elementwise_1d.cpp
Normal file
@@ -0,0 +1,216 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <vector>
|
||||
#include <cmath> // For std::abs
|
||||
#include <tuple>
|
||||
#include <type_traits> // For std::is_same_v, std::is_floating_point_v
|
||||
#include <utility> // For std::index_sequence, std::forward
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/host.hpp"
|
||||
#include "ck_tile/host/kernel_launch.hpp"
|
||||
#include "ck_tile/ops/elementwise/kernel/elementwise_kernel.hpp"
|
||||
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_problem.hpp"
|
||||
#include "ck_tile/ops/elementwise/pipeline/elementwise_pipeline_default_policy.hpp"
|
||||
#include "ck_tile/ops/elementwise/pipeline/elementwise_shape.hpp"
|
||||
#include "ck_tile/ops/elementwise/binary_elementwise_operation.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
// Traits to get number of inputs for an elementwise operation
|
||||
template <typename Op>
|
||||
struct elementwise_op_traits;
|
||||
|
||||
template <>
|
||||
struct elementwise_op_traits<ck_tile::element_wise::Add>
|
||||
{
|
||||
static constexpr int num_inputs = 2;
|
||||
};
|
||||
template <>
|
||||
struct elementwise_op_traits<ck_tile::element_wise::Relu>
|
||||
{
|
||||
static constexpr int num_inputs = 1;
|
||||
};
|
||||
|
||||
template <std::size_t D, typename F>
|
||||
auto make_uniform_array_with_factory(F&& factory)
|
||||
{
|
||||
return [&]<std::size_t... Is>(std::index_sequence<Is...>)
|
||||
{
|
||||
return std::array<std::invoke_result_t<F, std::size_t>, D>{factory(Is)...};
|
||||
}
|
||||
(std::make_index_sequence<D>{});
|
||||
}
|
||||
|
||||
template <typename Tuple>
|
||||
class TestCkTileElementwise : public ::testing::Test
|
||||
{
|
||||
protected:
|
||||
using XDataType = std::tuple_element_t<0, Tuple>;
|
||||
using YDataType = std::tuple_element_t<1, Tuple>;
|
||||
using ComputeDataType = std::tuple_element_t<2, Tuple>;
|
||||
using ElementwiseOpType = std::tuple_element_t<3, Tuple>;
|
||||
using BlockWarps_ = std::tuple_element_t<4, Tuple>;
|
||||
using BlockTile_ = std::tuple_element_t<5, Tuple>;
|
||||
using WarpTile_ = std::tuple_element_t<6, Tuple>;
|
||||
using TestElementWiseShape =
|
||||
ck_tile::ElementWiseShape<BlockWarps_, BlockTile_, WarpTile_, ComputeDataType>;
|
||||
static constexpr int NumInputs = elementwise_op_traits<ElementwiseOpType>::num_inputs;
|
||||
|
||||
// Runs the element-wise kernel on a 1D problem of total_m_elements elements
// and verifies the device result against a host reference computed with the
// same operation.
//
// @param total_m_elements  number of elements in the single (M) dimension
// @throws std::runtime_error if the kernel rejects the problem size
//         (IsSupportedArgument) or if the device output does not match the
//         host reference within tolerance.
void RunTest(ck_tile::index_t total_m_elements)
{
    // Dims and strides (1D example): a single dimension with unit stride,
    // used for both the inputs and the output.
    const auto lens = ck_tile::make_tuple(total_m_elements);
    const auto strides = ck_tile::make_tuple(
        static_cast<ck_tile::index_t>(1)); // Strides for the single dimension

    // Host tensors: NumInputs inputs filled with uniform random values in
    // [0, 5), plus zero-initialized device-result and reference buffers.
    auto h_xs = make_uniform_array_with_factory<NumInputs>([&](std::size_t) {
        auto ret = ck_tile::HostTensor<XDataType>({total_m_elements});
        ck_tile::FillUniformDistribution<XDataType>{0.f, 5.f}(ret);
        return ret;
    });
    ck_tile::HostTensor<YDataType> h_y({total_m_elements});
    h_y.SetZero();
    ck_tile::HostTensor<YDataType> h_y_ref({total_m_elements});
    h_y_ref.SetZero();

    // Device buffers: one per input tensor; only the inputs are copied to
    // the device (the output is produced by the kernel).
    auto d_xs_mems_owner = make_uniform_array_with_factory<NumInputs>(
        [&](std::size_t i) { return ck_tile::DeviceMem(h_xs[i]); });
    for(int i = 0; i < NumInputs; ++i)
    {
        d_xs_mems_owner[i].ToDevice(h_xs[i].data());
    }

    ck_tile::DeviceMem d_y_mem(h_y);
    d_y_mem.SetZero();

    // Tuple of NumInputs raw device pointers, built with an index-sequence
    // expansion so the kernel receives exactly one pointer per input.
    auto d_x_ptrs_tuple = [&]<std::size_t... Is>(std::index_sequence<Is...>)
    {
        return ck_tile::make_tuple(
            static_cast<const XDataType*>(d_xs_mems_owner[Is].GetDeviceBuffer())...);
    }
    (std::make_index_sequence<NumInputs>{});

    YDataType* p_y_device = static_cast<YDataType*>(d_y_mem.GetDeviceBuffer());

    // Problem and policy describing the pipeline for this configuration.
    using Problem = ck_tile::ElementWisePipelineProblem<XDataType,
                                                        ComputeDataType,
                                                        YDataType,
                                                        TestElementWiseShape,
                                                        ElementwiseOpType>;
    using Policy = ck_tile::ElementWiseDefaultPolicy;

    ck_tile::ElementWiseKernel<Problem, Policy> ew_kernel;

    // Launch configuration: one workgroup per block tile; ceiling division
    // so a partial tail tile still gets a workgroup.
    const ck_tile::index_t grid_size =
        (total_m_elements + TestElementWiseShape::kBlockM - 1) / TestElementWiseShape::kBlockM;
    dim3 grid(grid_size, 1, 1);
    dim3 block(TestElementWiseShape::kBlockSize, 1, 1);
    constexpr ck_tile::index_t kBlockPerCu = 1;

    ck_tile::stream_config s{nullptr, false, 0}; // Default stream, no timing, no log

    // Reject configurations the kernel cannot handle (e.g. sizes that are
    // not a multiple of the vector access width) before launching.
    if(!ew_kernel.IsSupportedArgument(lens))
    {
        throw std::runtime_error(
            "The kernel configuration is not supported for the given input size.");
    }

    ck_tile::launch_kernel(
        s,
        ck_tile::make_kernel<TestElementWiseShape::kBlockSize, // MaxThreadPerBlock
                             kBlockPerCu> // MinBlockPerCu
        (ew_kernel,
         grid,
         block,
         0, // actual shared memory
         lens,
         strides, // input strides
         strides, // output strides
         d_x_ptrs_tuple,
         p_y_device));

    d_y_mem.FromDevice(h_y.data());

    // Reference computation on host: apply the operation element by element,
    // widening each input to ComputeDataType before invoking the op.
    ElementwiseOpType op_host;
    for(ck_tile::index_t i = 0; i < total_m_elements; ++i)
    {
        auto get_host_op_args = [&]<std::size_t... Is>(std::index_sequence<Is...>)
        {
            return ck_tile::make_tuple(static_cast<ComputeDataType>(h_xs[Is](i))...);
        }
        (std::make_index_sequence<NumInputs>{});

        // Zero-initialize so the value is determinate even if the op ever
        // left it unwritten (the original left it uninitialized).
        YDataType temp_y_val{};
        ck_tile::apply(
            [&](auto&&... host_input_args) {
                op_host(temp_y_val,
                        std::forward<decltype(host_input_args)>(host_input_args)...);
            },
            get_host_op_args);
        h_y_ref(i) = temp_y_val;
    }

    // Check results. The original discarded check_err's boolean result, so a
    // mismatch printed a diagnostic but could not fail the test; propagate
    // failure as an exception (gtest reports uncaught exceptions as test
    // failures). NOTE(review): assumes check_err returns true on match --
    // confirm against the check_err overload this resolves to.
    if(!check_err(h_y, h_y_ref, "Error: Incorrect results!", 1e-5, 1e-5))
    {
        throw std::runtime_error("Error: Incorrect results!");
    }
}
};
|
||||
|
||||
// Shape parameters (can be shared or varied per test type).
// Together these give a 256-element block tile processed by one warp tile
// of 64 elements, arranged 1D along M.
using Shape1_BlockWarps = ck_tile::sequence<1>; // 1D warp arrangement in M
using Shape1_BlockTile = ck_tile::sequence<256>; // M-dimension of block tile
using Shape1_WarpTile = ck_tile::sequence<64>; // M-dimension of warp tile
// Test configurations.
// Tuple layout: <XDataType, YDataType, ComputeDataType, Op,
//                BlockWarps, BlockTile, WarpTile>
// float in/out/compute with a binary Add.
using TestConfig_F32_Add = std::tuple<float,
                                      float,
                                      float,
                                      ck_tile::element_wise::Add,
                                      Shape1_BlockWarps,
                                      Shape1_BlockTile,
                                      Shape1_WarpTile>;
// float in/out/compute with a unary Relu (same tuple layout as above:
// <X, Y, Compute, Op, BlockWarps, BlockTile, WarpTile>).
using TestConfig_F32_Relu = std::tuple<float,
                                       float,
                                       float,
                                       ck_tile::element_wise::Relu,
                                       Shape1_BlockWarps,
                                       Shape1_BlockTile,
                                       Shape1_WarpTile>;
// half_t in/out with Add; the compute type is widened to float
// (tuple layout: <X, Y, Compute, Op, BlockWarps, BlockTile, WarpTile>).
using TestConfig_F16_Add = std::tuple<ck_tile::half_t,
                                      ck_tile::half_t,
                                      float, // Compute in float for half
                                      ck_tile::element_wise::Add,
                                      Shape1_BlockWarps,
                                      Shape1_BlockTile,
                                      Shape1_WarpTile>;
// All type configurations exercised by the typed test suite: every
// TYPED_TEST below runs once per entry in this list.
using TestTypes = ::testing::Types<TestConfig_F32_Add, TestConfig_F32_Relu, TestConfig_F16_Add>;

TYPED_TEST_SUITE(TestCkTileElementwise, TestTypes);
// Straightforward case: 1024 is an exact multiple of the 256-element
// block tile, so every workgroup processes a full tile.
TYPED_TEST(TestCkTileElementwise, RunElementwise_1024)
{
    this->RunTest(1024);
}
// 513 is not a multiple of the vector access width (kVectorM); RunTest is
// expected to reject it (IsSupportedArgument fails) and throw.
TYPED_TEST(TestCkTileElementwise, RunElementwise_513)
{
    EXPECT_THROW((this->RunTest(513)), std::runtime_error);
}
// 516 is not a multiple of the block tile (blockM), so the last workgroup
// sees a partial tile; the run must still succeed and match the reference.
TYPED_TEST(TestCkTileElementwise, RunElementwise_516)
{
    this->RunTest(516);
}
// Very small input: 32 elements fit in a single, partially filled block
// tile (grid size 1 after ceiling division).
TYPED_TEST(TestCkTileElementwise, RunElementwise_Small_32)
{
    this->RunTest(32);
}