From 1ea1adcc38367e33764faa306695c9f09f4116c6 Mon Sep 17 00:00:00 2001 From: Damien Lejeune Date: Fri, 23 Jan 2026 11:48:52 -0500 Subject: [PATCH] WIP: start kernel implementation + test structure --- test/ck_tile/CMakeLists.txt | 1 + test/ck_tile/mhc/test_mhc.cpp | 4 +- test/ck_tile/mhc/test_mhc_impl.hpp | 297 +++++++++++++++-------------- 3 files changed, 158 insertions(+), 144 deletions(-) diff --git a/test/ck_tile/CMakeLists.txt b/test/ck_tile/CMakeLists.txt index 70649ed8f8..6c347eaf6b 100644 --- a/test/ck_tile/CMakeLists.txt +++ b/test/ck_tile/CMakeLists.txt @@ -41,3 +41,4 @@ add_subdirectory(fmha) add_subdirectory(gemm_tile_engine) add_subdirectory(pooling) add_subdirectory(grouped_conv) +add_subdirectory(mhc) diff --git a/test/ck_tile/mhc/test_mhc.cpp b/test/ck_tile/mhc/test_mhc.cpp index 7ff66173ac..5aeb2e7258 100644 --- a/test/ck_tile/mhc/test_mhc.cpp +++ b/test/ck_tile/mhc/test_mhc.cpp @@ -10,11 +10,9 @@ #include "ck_tile/core.hpp" #include "ck_tile/host.hpp" -#include "ck_tile/ops/reduce.hpp" #include "ck_tile/host/kernel_launch.hpp" -#include "ck_tile/ops/elementwise.hpp" -#include "test_multi_reduce2d_multiblock_impl.hpp" +#include "test_mhc_impl.hpp" // Shape parameters for different test configurations using Shape1_BlockWarps = ck_tile::sequence<4, 1>; diff --git a/test/ck_tile/mhc/test_mhc_impl.hpp b/test/ck_tile/mhc/test_mhc_impl.hpp index d0363eb62a..fc9ec695ae 100644 --- a/test/ck_tile/mhc/test_mhc_impl.hpp +++ b/test/ck_tile/mhc/test_mhc_impl.hpp @@ -43,166 +43,181 @@ class TestCkTileMHC : public ::testing::Test // ReduceDimSeq reduce_dims) void RunGenericTest() { - static_assert( - ReduceOpsType::size() == ElementwiseOpsType::size() && - ReduceOpsType::size() == AccumulatorOpsType::size() && - ReduceOpsType::size() == InterBlockReduceOpsType::size(), - "Error: All operations tuple size must match the number of reduction operations"); - const auto number_operations = ReduceOpsType::size(); + // Test parameters + const int B = 8; // Batch size + const int n = 4; // Expansion rate (aka streams) + const int C = 256; // Output layer dim + const int nC = n * C; // Total input dimension - ck_tile::HostTensor h_x(input_shape, input_strides); + const int output_dim = 2 * n + n * n; // 2n + n^2 = 8 + 16 = 24 for n=4 - auto h_ys = ck_tile::generate_tuple( - [&output_shape, &output_strides](auto /*i*/) { - return ck_tile::HostTensor(output_shape, output_strides); - }, - ck_tile::number{}); + // Allocate host tensors + ck_tile::HostTensor h_x({B, nC}); // Input [B, nC] + ck_tile::HostTensor h_phi({nC, output_dim}); // Weights [nC, 2n+n^2] + ck_tile::HostTensor h_output({B, output_dim}); // Output [B, 2n+n^2] - auto h_ys_ref = ck_tile::generate_tuple( - [&output_shape, &output_strides](auto /*i*/) { - return ck_tile::HostTensor(output_shape, output_strides); - }, - ck_tile::number{}); + // Initialize with random data + ck_tile::FillUniformDistribution{-1.0f, 1.0f}(h_x); + ck_tile::FillUniformDistribution{-0.5f, 0.5f}(h_phi); + h_output.SetZero(); - ck_tile::FillUniformDistribution{-5.f, 5.f}(h_x); - - ck_tile::static_for<0, number_operations, 1>{}([&](auto i) { - h_ys.template at().SetZero(); - h_ys_ref.template at().SetZero(); - }); - - auto output_number_elements = [&output_shape]() { - ck_tile::index_t prod = 1; - for(auto len : output_shape) - prod *= len; - return prod; - }(); - - auto output_buffer_size = - number_operations * h_ys.get(ck_tile::number<0>{}).get_element_space_size_in_bytes(); + // Allocate device memory ck_tile::DeviceMem d_x_mem(h_x.get_element_space_size_in_bytes()); - ck_tile::DeviceMem d_y_mem(output_buffer_size); - - std::vector h(number_operations * output_number_elements); - - // Init the output data with identity values respective to each reduce op - ck_tile::static_for<0, number_operations, 1>{}([&](auto i) { - constexpr auto op = ReduceOpsType{}.at(i); - const auto identity_val = op.template GetIdentityValue(); - std::fill(h.begin() + i * output_number_elements, - h.begin() + (i + 1) * output_number_elements, - identity_val); - }); + ck_tile::DeviceMem d_phi_mem(h_phi.get_element_space_size_in_bytes()); + ck_tile::DeviceMem d_output_mem(h_output.get_element_space_size_in_bytes()); + // Copy data to device d_x_mem.ToDevice(h_x.data()); - d_y_mem.ToDevice(h.data()); + d_phi_mem.ToDevice(h_phi.data()); + d_output_mem.ToDevice(h_output.data()); - using Problem = ck_tile::Reduce2dProblem; - - using Kernel = ck_tile::MultiReduceMultiblock; - - // Launch configuration - const ck_tile::index_t kBlockSize = Kernel::BlockSize(); + // Kernel launch configuration + const ck_tile::index_t kBlockSize = 256; // 256 threads per block + const ck_tile::index_t kGridSize = B; // One block per batch element constexpr ck_tile::index_t kBlockPerCu = 1; - auto elementwise_ops = - make_elementwise_ops_tuple(total_reduce_elements, ElementwiseOpsType{}); - auto accumulator_ops = - make_elementwise_ops_tuple(total_reduce_elements, AccumulatorOpsType{}); + // TODO: Define Problem and Policy types + // using Problem = ck_tile::MHCProblem<...>; + // using Kernel = ck_tile::ManifoldConstrainedHyperConnection; - auto [num_block_tile_iterations, block_group_size] = - typename Kernel::TilePartitioner{total_reduce_elements}.GetBlockGroupParams(); - - std::cout << "Block group size: " << block_group_size - << ", Num block tile iterations: " << num_block_tile_iterations - << ", Reduce total length: " << total_reduce_elements << std::endl; - - ck_tile::index_t kGridSize = - ((kept_dim_len_prod + TestReduce2dShape::Block_M - 1) / TestReduce2dShape::Block_M) * - block_group_size; - - // Generic helper to create tuple from vector based on compile-time size - auto make_shape_tuple = [](const std::vector& vec) { - return [&vec](std::index_sequence) { - return ck_tile::make_tuple(vec[I]...); - }(std::make_index_sequence{}); - }; - - auto input_shape_tuple = make_shape_tuple.template operator()(input_shape); - auto input_strides_tuple = make_shape_tuple.template operator()(input_strides); - - if(!Kernel::IsSupportedArgument( - total_reduce_elements, - input_strides_tuple)) // output tensor's continuous dimension - { - throw std::runtime_error("Wrong! Arguments not supported!\n"); - } + std::cout << "Launching MHC kernel with:" << std::endl; + std::cout << " Batch size (B): " << B << std::endl; + std::cout << " Expansion factor (n): " << n << std::endl; + std::cout << " Channels per stream (C): " << C << std::endl; + std::cout << " Input dimension (nC): " << nC << std::endl; + std::cout << " Output dimension (2n+n²): " << output_dim << std::endl; + std::cout << " Grid size: " << kGridSize << std::endl; + std::cout << " Block size: " << kBlockSize << std::endl; + // Kernel launch + /* ck_tile::launch_kernel( ck_tile::stream_config{nullptr, false, 0}, - ck_tile::make_kernel(Kernel{}, - kGridSize, - kBlockSize, - 0, - static_cast(d_x_mem.GetDeviceBuffer()), - static_cast(d_y_mem.GetDeviceBuffer()), - input_shape_tuple, - input_strides_tuple, - kept_dims, - reduce_dims, - output_number_elements, - elementwise_ops, - accumulator_ops, - InterBlockReduceOpsType{})); + ck_tile::make_kernel( + Kernel{}, + kGridSize, + kBlockSize, + 0, // shared memory size + static_cast(d_x_mem.GetDeviceBuffer()), + static_cast(d_phi_mem.GetDeviceBuffer()), + static_cast(d_output_mem.GetDeviceBuffer()), + B, n, C)); + */ - // Reference computation - ck_tile::reference_multiple_reduce_multiblock( - h_x, - h_ys_ref, - ReduceOpsType{}, - kept_dims, - reduce_dims, - elementwise_ops, - accumulator_ops, - InterBlockReduceOpsType{}, - block_group_size); + // Copy results back to host + // d_output_mem.FromDevice(h_output.data()); - // Calculate proper error thresholds based on data types and number of accumulations - // const auto rtol = ck_tile::get_relative_threshold( - // total_reduce_elements); - // const auto atol = ck_tile::get_absolute_threshold( - // 5.0f, total_reduce_elements); + // TODO: Add reference computation and validation - // Unfortunately due to the non-sequenciality, down-casting on the output buffer - // and further operations on this buffer, the error is compounding at a faster - // rate than what the host reference can support. A large tolerance is then required - const auto rtol = 1e-2; - const auto atol = 1e-1; + // auto h_ys = ck_tile::generate_tuple( + // [&output_shape, &output_strides](auto /*i*/) { + // return ck_tile::HostTensor(output_shape, output_strides); + // }, + // ck_tile::number{}); - // Transfer data from device and check error for each operation - std::vector h_y_tmp(output_number_elements * number_operations); - d_y_mem.FromDevice(h_y_tmp.data()); - bool result = true; - ck_tile::static_for<0, number_operations, 1>{}([&](auto i) { - std::memcpy(h_ys.get(ck_tile::number{}).data(), - h_y_tmp.data() + i * output_number_elements, - output_number_elements * sizeof(YDataType)); - std::cout << "Checking errors for operation: " << i << std::endl; - result &= ck_tile::check_err(h_ys.get(ck_tile::number{}), - h_ys_ref.get(ck_tile::number{}), - "Error: Incorrect reduce results!", - rtol, - atol); - }); + // auto h_ys_ref = ck_tile::generate_tuple( + // [&output_shape, &output_strides](auto /*i*/) { + // return ck_tile::HostTensor(output_shape, output_strides); + // }, + // ck_tile::number{}); - EXPECT_TRUE(result); + // ck_tile::FillUniformDistribution{-5.f, 5.f}(h_x); + + // ck_tile::static_for<0, number_operations, 1>{}([&](auto i) { + // h_ys.template at().SetZero(); + // h_ys_ref.template at().SetZero(); + // }); + + // auto output_number_elements = [&output_shape]() { + // ck_tile::index_t prod = 1; + // for(auto len : output_shape) + // prod *= len; + // return prod; + // }(); + + // auto output_buffer_size = + // number_operations * h_ys.get(ck_tile::number<0>{}).get_element_space_size_in_bytes(); + // ck_tile::DeviceMem d_x_mem(h_x.get_element_space_size_in_bytes()); + // ck_tile::DeviceMem d_y_mem(output_buffer_size); + + // std::vector h(number_operations * output_number_elements); + + // // Init the output data with identity values respective to each reduce op + // ck_tile::static_for<0, number_operations, 1>{}([&](auto i) { + // constexpr auto op = ReduceOpsType{}.at(i); + // const auto identity_val = op.template GetIdentityValue(); + // std::fill(h.begin() + i * output_number_elements, + // h.begin() + (i + 1) * output_number_elements, + // identity_val); + // }); + + // d_x_mem.ToDevice(h_x.data()); + // d_y_mem.ToDevice(h.data()); + + // using Problem = ck_tile::Reduce2dProblem; + + // using Kernel = ck_tile::MultiReduceMultiblock; + + // // Launch configuration + // const ck_tile::index_t kBlockSize = Kernel::BlockSize(); + // constexpr ck_tile::index_t kBlockPerCu = 1; + + // auto elementwise_ops = + // make_elementwise_ops_tuple(total_reduce_elements, ElementwiseOpsType{}); + // auto accumulator_ops = + // make_elementwise_ops_tuple(total_reduce_elements, AccumulatorOpsType{}); + + // auto [num_block_tile_iterations, block_group_size] = + // typename Kernel::TilePartitioner{total_reduce_elements}.GetBlockGroupParams(); + + // std::cout << "Block group size: " << block_group_size + // << ", Num block tile iterations: " << num_block_tile_iterations + // << ", Reduce total length: " << total_reduce_elements << std::endl; + + // ck_tile::index_t kGridSize = + // ((kept_dim_len_prod + TestReduce2dShape::Block_M - 1) / TestReduce2dShape::Block_M) * + // block_group_size; + + // // Generic helper to create tuple from vector based on compile-time size + // auto make_shape_tuple = [](const std::vector& vec) { + // return [&vec](std::index_sequence) { + // return ck_tile::make_tuple(vec[I]...); + // }(std::make_index_sequence{}); + // }; + + // auto input_shape_tuple = make_shape_tuple.template operator()(input_shape); + // auto input_strides_tuple = make_shape_tuple.template operator()(input_strides); + + // if(!Kernel::IsSupportedArgument()) // TODO + // { + // } + + // ck_tile::launch_kernel( + // ck_tile::stream_config{nullptr, false, 0}, + // ck_tile::make_kernel(Kernel{}, + // kGridSize, + // kBlockSize, + // 0, + // static_cast(d_x_mem.GetDeviceBuffer()), + // static_cast(d_y_mem.GetDeviceBuffer()), + // input_shape_tuple, + // input_strides_tuple, + // kept_dims, + // reduce_dims, + // output_number_elements, + // elementwise_ops, + // accumulator_ops, + // InterBlockReduceOpsType{})); + + // TODO: Reference computation + Transfer data back to host + EXPECT_TRUE(true); } };