From e7ebd6c288212d22a05bd2c5004244be8c2ce454 Mon Sep 17 00:00:00 2001 From: Damien Lejeune Date: Fri, 6 Feb 2026 09:44:20 +0000 Subject: [PATCH] Readd naive normalization in mhc v3 --- .../ck_tile/host/reference/reference_mhc.hpp | 30 +- .../ops/mhc/kernel/mhc_kernel_tile_v3.hpp | 85 +- test/ck_tile/mhc/test_mhc.cpp | 35 +- test/ck_tile/mhc/test_mhc_impl.hpp | 1097 +++++++++-------- 4 files changed, 704 insertions(+), 543 deletions(-) diff --git a/include/ck_tile/host/reference/reference_mhc.hpp b/include/ck_tile/host/reference/reference_mhc.hpp index 33134f3b3b..4588d78eac 100644 --- a/include/ck_tile/host/reference/reference_mhc.hpp +++ b/include/ck_tile/host/reference/reference_mhc.hpp @@ -57,13 +57,13 @@ CK_TILE_HOST void reference_mhc(const HostTensor& x_b_nc, // [B type_convert(phi_nc_out(k, out_idx)); } // // Step 4: Apply activation σ(H^{pre}) - // ComputeDataType activated_value; - // activation(activated_value, sum); - // output_b_out(b, out_idx) = - // type_convert((alpha_pre / r) * activated_value + bias); + ComputeDataType activated_value; + activation(activated_value, sum); + output_b_out(b, out_idx) = + type_convert((alpha_pre / norm) * activated_value + bias); // TESTING: Store raw GEMM output - output_b_out(b, out_idx) = type_convert(sum); + // output_b_out(b, out_idx) = type_convert(sum); } // Process H^{post}: x * phi[:, n:2n] -> 2*sigma(output[:, n:2n]) @@ -76,13 +76,13 @@ CK_TILE_HOST void reference_mhc(const HostTensor& x_b_nc, // [B type_convert(phi_nc_out(k, n + out_idx)); } // // Step 5: Apply 2*σ(H^{post}) - // ComputeDataType activated_value; - // activation(activated_value, sum); - // output_b_out(b, n + out_idx) = - // type_convert((alpha_post / r) * 2.0f * activated_value + bias); + ComputeDataType activated_value; + activation(activated_value, sum); + output_b_out(b, n + out_idx) = + type_convert((alpha_post / norm) * 2.0f * activated_value + bias); // TESTING: Store raw GEMM output - output_b_out(b, n + out_idx) = type_convert(sum); + // output_b_out(b, n + out_idx) = type_convert(sum); } // Process H^{res}: x * phi[:, 2n:2n+n^2] -> output[:, 2n:2n+n^2] @@ -95,17 +95,17 @@ CK_TILE_HOST void reference_mhc(const HostTensor& x_b_nc, // [B sum += type_convert(x_b_nc(b, k)) * type_convert(phi_nc_out(k, 2 * n + out_idx)); } - // // Apply: 1/r * alpha_res * sum + bias - // output_b_out(b, 2 * n + out_idx) = - // type_convert((alpha_res / r) * sum + bias); + // Apply: 1/r * alpha_res * sum + bias + output_b_out(b, 2 * n + out_idx) = + type_convert((alpha_res / norm) * sum + bias); // TESTING: Store raw GEMM output - output_b_out(b, 2 * n + out_idx) = type_convert(sum); + // output_b_out(b, 2 * n + out_idx) = type_convert(sum); } // Note: norm is computed but not currently used in the output // It could be used for additional normalization if needed - (void)norm; + // (void)norm; }; make_ParallelTensorFunctor(f_batch, B)(std::thread::hardware_concurrency()); diff --git a/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v3.hpp b/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v3.hpp index 795414dc3e..5a1156ac3d 100644 --- a/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v3.hpp +++ b/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v3.hpp @@ -145,6 +145,53 @@ struct MHCKernelV3 auto result_tile = gemm_pipeline( make_tuple(x_dram_window), make_tuple(phi_dram_window), num_k_loops, smem); + // Compute norm ||x_l||_2 / sqrt(nC) for each batch element using vectorized loads + // Use vector loads (float4) for better memory bandwidth utilization + constexpr index_t kVectorSize = 4; // Load 4 floats at a time + + ComputeDataType norms[kMTile]; + + for(index_t local_m = 0; local_m < kMTile; ++local_m) + { + const index_t global_m = batch_start + local_m; + if(global_m < batch) + { + ComputeDataType sum_squares = 0.0f; + const XDataType* row_ptr = p_x + global_m * nC; + + // Vectorized loop: process kVectorSize elements at a time + index_t k = 0; + for(; k + kVectorSize <= nC; k += kVectorSize) + { + // Load vector of elements + using VecType = ext_vector_t; + VecType vec_data = *c_style_pointer_cast(row_ptr + k); + +// Accumulate squares +#pragma unroll + for(index_t i = 0; i < kVectorSize; ++i) + { + ComputeDataType val = type_convert(vec_data[i]); + sum_squares += val * val; + } + } + + // Handle remaining elements (scalar loop) + for(; k < nC; ++k) + { + ComputeDataType val = type_convert(row_ptr[k]); + sum_squares += val * val; + } + + norms[local_m] = + ck_tile::sqrt(sum_squares) / ck_tile::sqrt(static_cast(nC)); + } + else + { + norms[local_m] = 1.0f; // Default for out-of-bounds + } + } + // Apply elementwise operations (currently commented out for GEMM testing) constexpr auto result_spans = decltype(result_tile)::get_distributed_spans(); @@ -163,26 +210,26 @@ struct MHCKernelV3 constexpr auto i_j_idx = make_tuple(idx0, idx1); [[maybe_unused]] ComputeDataType value = result_tile[i_j_idx]; - // TESTING: Comment out post-GEMM operations to validate GEMM only - // // Apply activation based on output section - // if(global_n < n) - // { - // ComputeDataType activated_value; - // Activation{}(activated_value, value); - // value = (alpha_pre / r) * activated_value + bias; - // } - // else if(global_n < 2 * n) - // { - // ComputeDataType activated_value; - // Activation{}(activated_value, value); - // value = (alpha_post / r) * 2.0f * activated_value + bias; - // } - // else - // { - // value = (alpha_res / r) * value + bias; - // } + // Get the norm for this batch element + const ComputeDataType norm = norms[local_m]; - // p_output[global_m * output_dim + global_n] = type_convert(value); + // Apply activation based on output section + if(global_n < n) + { + ComputeDataType activated_value; + Activation{}(activated_value, value); + result_tile(i_j_idx) = (alpha_pre / norm) * activated_value + bias; + } + else if(global_n < 2 * n) + { + ComputeDataType activated_value; + Activation{}(activated_value, value); + result_tile(i_j_idx) = (alpha_post / norm) * 2.0f * activated_value + bias; + } + else + { + result_tile(i_j_idx) = (alpha_res / norm) * value + bias; + } } }); }); diff --git a/test/ck_tile/mhc/test_mhc.cpp b/test/ck_tile/mhc/test_mhc.cpp index 33c9bb58ef..a2a3be8a13 100644 --- a/test/ck_tile/mhc/test_mhc.cpp +++ b/test/ck_tile/mhc/test_mhc.cpp @@ -37,49 +37,46 @@ TYPED_TEST_SUITE(TestCkTileMHC, TestTypes); // TYPED_TEST(TestCkTileMHC, TestArbitraryBatchSize) { this->RunArbitraryBatchSizeTest(); } // Explicit test cases for specific batch sizes (using template syntax) -TYPED_TEST(TestCkTileMHC, TestBatchSize1) { this->template RunBatchSizeTest<1>(); } +TYPED_TEST(TestCkTileMHC, TestBatchSize1) { this->template RunGemmPipeline<1>(); } -TYPED_TEST(TestCkTileMHC, TestBatchSize16) { this->template RunBatchSizeTest<16>(); } +TYPED_TEST(TestCkTileMHC, TestBatchSize16) { this->template RunGemmPipeline<16>(); } -TYPED_TEST(TestCkTileMHC, TestBatchSize17) { this->template RunBatchSizeTest<17>(); } +TYPED_TEST(TestCkTileMHC, TestBatchSize17) { this->template RunGemmPipeline<17>(); } -TYPED_TEST(TestCkTileMHC, TestBatchSize32) { this->template RunBatchSizeTest<32>(); } +TYPED_TEST(TestCkTileMHC, TestBatchSize32) { this->template RunGemmPipeline<32>(); } -TYPED_TEST(TestCkTileMHC, TestBatchSize64) { this->template RunBatchSizeTest<64>(); } +TYPED_TEST(TestCkTileMHC, TestBatchSize64) { this->template RunGemmPipeline<64>(); } -TYPED_TEST(TestCkTileMHC, TestBatchSize100) { this->template RunBatchSizeTest<100>(); } +TYPED_TEST(TestCkTileMHC, TestBatchSize100) { this->template RunGemmPipeline<100>(); } // Test with different expansion factors (keeping nC <= 256) -TYPED_TEST(TestCkTileMHC, TestBatchSize16N2C128) { this->template RunBatchSizeTest<16, 2, 128>(); } +TYPED_TEST(TestCkTileMHC, TestBatchSize16N2C128) { this->template RunGemmPipeline<16, 2, 128>(); } -TYPED_TEST(TestCkTileMHC, TestBatchSize32N3C85) { this->template RunBatchSizeTest<32, 3, 85>(); } +TYPED_TEST(TestCkTileMHC, TestBatchSize32N3C85) { this->template RunGemmPipeline<32, 3, 85>(); } -TYPED_TEST(TestCkTileMHC, TestBatchSize8N8C32) { this->template RunBatchSizeTest<8, 8, 32>(); } +TYPED_TEST(TestCkTileMHC, TestBatchSize8N8C32) { this->template RunGemmPipeline<8, 8, 32>(); } // Test with large C values that require K-tiling (nC > 256) -TYPED_TEST(TestCkTileMHC, TestBatchSize2N4C512) { this->template RunBatchSizeTest<2, 4, 512>(); } +TYPED_TEST(TestCkTileMHC, TestBatchSize2N4C512) { this->template RunGemmPipeline<2, 4, 512>(); } -TYPED_TEST(TestCkTileMHC, TestBatchSize2N4C1024) { this->template RunBatchSizeTest<2, 4, 1024>(); } +TYPED_TEST(TestCkTileMHC, TestBatchSize2N4C1024) { this->template RunGemmPipeline<2, 4, 1024>(); } -TYPED_TEST(TestCkTileMHC, TestBatchSize2N4C4096) { this->template RunBatchSizeTest<2, 4, 4096>(); } +TYPED_TEST(TestCkTileMHC, TestBatchSize2N4C4096) { this->template RunGemmPipeline<2, 4, 4096>(); } // Test with different activation functions TYPED_TEST(TestCkTileMHC, TestBatchSize16WithTanh) { - this->template RunBatchSizeTestWithActivation<16, 4, 64, ck_tile::element_wise::TanH>(); + this->template RunGemmPipeline<16, 4, 64, ck_tile::element_wise::TanH>(); } TYPED_TEST(TestCkTileMHC, TestBatchSize16WithRelu) { - this->template RunBatchSizeTestWithActivation<16, 4, 64, ck_tile::element_wise::Relu>(); + this->template RunGemmPipeline<16, 4, 64, ck_tile::element_wise::Relu>(); } TYPED_TEST(TestCkTileMHC, TestBatchSize16WithSilu) { - this->template RunBatchSizeTestWithActivation<16, 4, 64, ck_tile::element_wise::Silu>(); + this->template RunGemmPipeline<16, 4, 64, ck_tile::element_wise::Silu>(); } -TYPED_TEST(TestCkTileMHC, TestBatchSize16N4C4096) -{ - this->template RunBatchSizeTest<16, 4, 4096>(); -} +TYPED_TEST(TestCkTileMHC, TestBatchSize16N4C4096) { this->template RunGemmPipeline<16, 4, 4096>(); } diff --git a/test/ck_tile/mhc/test_mhc_impl.hpp b/test/ck_tile/mhc/test_mhc_impl.hpp index 24775541f9..b2e9a7d09c 100644 --- a/test/ck_tile/mhc/test_mhc_impl.hpp +++ b/test/ck_tile/mhc/test_mhc_impl.hpp @@ -44,244 +44,620 @@ class TestCkTileMHC : public ::testing::Test // ck_tile::index_t total_reduce_elements, // KeptDimSeq kept_dims, // ReduceDimSeq reduce_dims) - void RunGenericTest() - { - // Test parameters - const int B = 8; // Batch size - const int n = 4; // Expansion rate (aka streams) - const int C = 64; // Output layer dim (reduced to avoid shared memory overflow) - const int nC = n * C; // Total input dimension + // void RunGenericTest() + // { - const int output_dim = 2 * n + n * n; // 2n + n^2 = 8 + 16 = 24 for n=4 + // // Test parameters + // const int B = 8; // Batch size + // const int n = 4; // Expansion rate (aka streams) + // const int C = 64; // Output layer dim (reduced to avoid shared memory overflow) + // const int nC = n * C; // Total input dimension - // Allocate host tensors - ck_tile::HostTensor h_x({B, nC}); // Input [B, nC] - ck_tile::HostTensor h_phi({nC, output_dim}); // Weights [nC, 2n+n^2] - ck_tile::HostTensor h_output({B, output_dim}); // Output [B, 2n+n^2] + // const int output_dim = 2 * n + n * n; // 2n + n^2 = 8 + 16 = 24 for n=4 - // Initialize with random data - ck_tile::FillUniformDistribution{-1.0f, 1.0f}(h_x); - ck_tile::FillUniformDistribution{-0.5f, 0.5f}(h_phi); - h_output.SetZero(); + // // Allocate host tensors + // ck_tile::HostTensor h_x({B, nC}); // Input [B, nC] + // ck_tile::HostTensor h_phi({nC, output_dim}); // Weights [nC, 2n+n^2] + // ck_tile::HostTensor h_output({B, output_dim}); // Output [B, 2n+n^2] - // Allocate device memory - ck_tile::DeviceMem d_x_mem(h_x.get_element_space_size_in_bytes()); - ck_tile::DeviceMem d_phi_mem(h_phi.get_element_space_size_in_bytes()); - ck_tile::DeviceMem d_output_mem(h_output.get_element_space_size_in_bytes()); + // // Initialize with random data + // ck_tile::FillUniformDistribution{-1.0f, 1.0f}(h_x); + // ck_tile::FillUniformDistribution{-0.5f, 0.5f}(h_phi); + // h_output.SetZero(); - // Copy data to device - d_x_mem.ToDevice(h_x.data()); - d_phi_mem.ToDevice(h_phi.data()); - d_output_mem.ToDevice(h_output.data()); + // // Allocate device memory + // ck_tile::DeviceMem d_x_mem(h_x.get_element_space_size_in_bytes()); + // ck_tile::DeviceMem d_phi_mem(h_phi.get_element_space_size_in_bytes()); + // ck_tile::DeviceMem d_output_mem(h_output.get_element_space_size_in_bytes()); - // DEBUG: Print first few values of h_x to compare with x_lds - std::cout << "DEBUG h_x[0, 0:4]: " << h_x(0, 0) << ", " << h_x(0, 1) << ", " << h_x(0, 2) - << ", " << h_x(0, 3) << std::endl; - std::cout << "DEBUG h_x[1, 0:4]: " << h_x(1, 0) << ", " << h_x(1, 1) << ", " << h_x(1, 2) - << ", " << h_x(1, 3) << std::endl; + // // Copy data to device + // d_x_mem.ToDevice(h_x.data()); + // d_phi_mem.ToDevice(h_phi.data()); + // d_output_mem.ToDevice(h_output.data()); - // DEBUG: Print first few values of h_phi column 0 (stream_id=0) - std::cout << "DEBUG h_phi[0:4, 0]: " << h_phi(0, 0) << ", " << h_phi(1, 0) << ", " - << h_phi(2, 0) << ", " << h_phi(3, 0) << std::endl; + // // DEBUG: Print first few values of h_x to compare with x_lds + // std::cout << "DEBUG h_x[0, 0:4]: " << h_x(0, 0) << ", " << h_x(0, 1) << ", " << h_x(0, 2) + // << ", " << h_x(0, 3) << std::endl; + // std::cout << "DEBUG h_x[1, 0:4]: " << h_x(1, 0) << ", " << h_x(1, 1) << ", " << h_x(1, 2) + // << ", " << h_x(1, 3) << std::endl; - // Define block shape for the kernel - // For simplicity, we use a basic configuration - using BlockShape = - ck_tile::Generic2dBlockShape, // Block tile size [M, N] - 1 - // row, 256 columns - ck_tile::sequence<1, 256>, // Threads per block [M, N] - ck_tile::sequence<1, 1> // Vector size [M, N] - >; + // // DEBUG: Print first few values of h_phi column 0 (stream_id=0) + // std::cout << "DEBUG h_phi[0:4, 0]: " << h_phi(0, 0) << ", " << h_phi(1, 0) << ", " + // << h_phi(2, 0) << ", " << h_phi(3, 0) << std::endl; - // Define the Problem type - using Problem = ck_tile::MHCProblem; + // // Define block shape for the kernel + // // For simplicity, we use a basic configuration + // using BlockShape = + // ck_tile::Generic2dBlockShape, // Block tile size [M, N] - 1 + // // row, 256 columns + // ck_tile::sequence<1, 256>, // Threads per block [M, N] + // ck_tile::sequence<1, 1> // Vector size [M, N] + // >; - // Define the Kernel type with default policy (naive version) - using Kernel = - ck_tile::ManifoldConstrainedHyperConnection; + // // Define the Problem type + // using Problem = ck_tile::MHCProblem; - // Define the CK Tile version kernel (v2 with proper tiling) - // Use compile-time parameters for B, n, C - using KernelCKTile = ck_tile:: - ManifoldConstrainedHyperConnectionTiled; + // // Define the Kernel type with default policy (naive version) + // using Kernel = + // ck_tile::ManifoldConstrainedHyperConnection; - // Kernel launch configuration - const ck_tile::index_t kBlockSize = Kernel::BlockSize(); - const ck_tile::index_t kGridSize = B; // One block per batch element - constexpr ck_tile::index_t kBlockPerCu = 1; + // // Define the CK Tile version kernel (v2 with proper tiling) + // // Use compile-time parameters for B, n, C + // using KernelCKTile = ck_tile:: + // ManifoldConstrainedHyperConnectionTiled; - std::cout << "Launching MHC kernel (naive version) with:" << std::endl; - std::cout << " Batch size (B): " << B << std::endl; - std::cout << " Expansion factor (n): " << n << std::endl; - std::cout << " Channels per stream (C): " << C << std::endl; - std::cout << " Input dimension (nC): " << nC << std::endl; - std::cout << " Output dimension (2n+n²): " << output_dim << std::endl; - std::cout << " Grid size: " << kGridSize << std::endl; - std::cout << " Block size: " << kBlockSize << std::endl; + // // Kernel launch configuration + // const ck_tile::index_t kBlockSize = Kernel::BlockSize(); + // const ck_tile::index_t kGridSize = B; // One block per batch element + // constexpr ck_tile::index_t kBlockPerCu = 1; - // Get shared memory size - const ck_tile::index_t smem_size = Kernel::GetSmemSize(); - std::cout << " Shared memory size: " << smem_size << " bytes" << std::endl; + // std::cout << "Launching MHC kernel (naive version) with:" << std::endl; + // std::cout << " Batch size (B): " << B << std::endl; + // std::cout << " Expansion factor (n): " << n << std::endl; + // std::cout << " Channels per stream (C): " << C << std::endl; + // std::cout << " Input dimension (nC): " << nC << std::endl; + // std::cout << " Output dimension (2n+n²): " << output_dim << std::endl; + // std::cout << " Grid size: " << kGridSize << std::endl; + // std::cout << " Block size: " << kBlockSize << std::endl; - // Kernel parameters - const float r = 1.0f; - const float alpha_pre = 1.0f; - const float alpha_post = 1.0f; - const float alpha_res = 1.0f; - const float bias = 0.0f; + // // Get shared memory size + // const ck_tile::index_t smem_size = Kernel::GetSmemSize(); + // std::cout << " Shared memory size: " << smem_size << " bytes" << std::endl; - // Kernel launch - ck_tile::launch_kernel( - ck_tile::stream_config{nullptr, false, 0}, - ck_tile::make_kernel(Kernel{}, - kGridSize, - kBlockSize, - smem_size, - static_cast(d_x_mem.GetDeviceBuffer()), - static_cast(d_phi_mem.GetDeviceBuffer()), - static_cast(d_output_mem.GetDeviceBuffer()), - B, - n, - C, - r, - alpha_pre, - alpha_post, - alpha_res, - bias)); + // // Kernel parameters + // const float r = 1.0f; + // const float alpha_pre = 1.0f; + // const float alpha_post = 1.0f; + // const float alpha_res = 1.0f; + // const float bias = 0.0f; - // Copy results back to host - d_output_mem.FromDevice(h_output.data()); + // // Kernel launch + // ck_tile::launch_kernel( + // ck_tile::stream_config{nullptr, false, 0}, + // ck_tile::make_kernel(Kernel{}, + // kGridSize, + // kBlockSize, + // smem_size, + // static_cast(d_x_mem.GetDeviceBuffer()), + // static_cast(d_phi_mem.GetDeviceBuffer()), + // static_cast(d_output_mem.GetDeviceBuffer()), + // B, + // n, + // C, + // r, + // alpha_pre, + // alpha_post, + // alpha_res, + // bias)); - std::cout << "Kernel launched successfully!" << std::endl; + // // Copy results back to host + // d_output_mem.FromDevice(h_output.data()); - // Print output to verify kernel actually modified the tensor - std::cout << "\nOutput tensor (first 2 batches, all " << output_dim - << " elements):" << std::endl; - for(int b = 0; b < std::min(2, B); b++) - { - std::cout << "Batch " << b << ": ["; - for(int i = 0; i < output_dim; i++) - { - std::cout << h_output(b, i); - if(i < output_dim - 1) - std::cout << ", "; - } - std::cout << "]" << std::endl; - } + // std::cout << "Kernel launched successfully!" << std::endl; - // Verify that output is not all zeros (kernel actually ran) - bool has_nonzero = false; - for(int b = 0; b < B && !has_nonzero; b++) - { - for(int i = 0; i < output_dim && !has_nonzero; i++) - { - if(std::abs(h_output(b, i)) > 1e-6f) - { - has_nonzero = true; - } - } - } + // // Print output to verify kernel actually modified the tensor + // std::cout << "\nOutput tensor (first 2 batches, all " << output_dim + // << " elements):" << std::endl; + // for(int b = 0; b < std::min(2, B); b++) + // { + // std::cout << "Batch " << b << ": ["; + // for(int i = 0; i < output_dim; i++) + // { + // std::cout << h_output(b, i); + // if(i < output_dim - 1) + // std::cout << ", "; + // } + // std::cout << "]" << std::endl; + // } - std::cout << "\nNaive kernel output verification: " - << (has_nonzero ? "PASS (non-zero values found)" : "FAIL (all zeros)") - << std::endl; + // // Verify that output is not all zeros (kernel actually ran) + // bool has_nonzero = false; + // for(int b = 0; b < B && !has_nonzero; b++) + // { + // for(int i = 0; i < output_dim && !has_nonzero; i++) + // { + // if(std::abs(h_output(b, i)) > 1e-6f) + // { + // has_nonzero = true; + // } + // } + // } - // Test CK Tile version - std::cout << "\n========================================" << std::endl; - std::cout << "Testing CK Tile version kernel..." << std::endl; - std::cout << "========================================" << std::endl; + // std::cout << "\nNaive kernel output verification: " + // << (has_nonzero ? "PASS (non-zero values found)" : "FAIL (all zeros)") + // << std::endl; - ck_tile::HostTensor h_output_cktile({B, output_dim}); - h_output_cktile.SetZero(); + // // Test CK Tile version + // std::cout << "\n========================================" << std::endl; + // std::cout << "Testing CK Tile version kernel..." << std::endl; + // std::cout << "========================================" << std::endl; - ck_tile::DeviceMem d_output_cktile_mem(h_output_cktile.get_element_space_size_in_bytes()); - d_output_cktile_mem.ToDevice(h_output_cktile.data()); + // ck_tile::HostTensor h_output_cktile({B, output_dim}); + // h_output_cktile.SetZero(); - // Launch CK Tile kernel (B, n, C are template parameters) - ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0}, - ck_tile::make_kernel( - KernelCKTile{}, - kGridSize, - kBlockSize, - smem_size, - static_cast(d_x_mem.GetDeviceBuffer()), - static_cast(d_phi_mem.GetDeviceBuffer()), - static_cast(d_output_cktile_mem.GetDeviceBuffer()), - r, - alpha_pre, - alpha_post, - alpha_res, - bias)); + // ck_tile::DeviceMem + // d_output_cktile_mem(h_output_cktile.get_element_space_size_in_bytes()); + // d_output_cktile_mem.ToDevice(h_output_cktile.data()); - d_output_cktile_mem.FromDevice(h_output_cktile.data()); + // // Launch CK Tile kernel (B, n, C are template parameters) + // ck_tile::launch_kernel(ck_tile::stream_config{nullptr, false, 0}, + // ck_tile::make_kernel( + // KernelCKTile{}, + // kGridSize, + // kBlockSize, + // smem_size, + // static_cast(d_x_mem.GetDeviceBuffer()), + // static_cast(d_phi_mem.GetDeviceBuffer()), + // static_cast(d_output_cktile_mem.GetDeviceBuffer()), + // r, + // alpha_pre, + // alpha_post, + // alpha_res, + // bias)); - std::cout << "\nCK Tile kernel output (first 2 batches):" << std::endl; - for(int b = 0; b < std::min(2, B); b++) - { - std::cout << "Batch " << b << ": ["; - for(int i = 0; i < output_dim; i++) - { - std::cout << h_output_cktile(b, i); - if(i < output_dim - 1) - std::cout << ", "; - } - std::cout << "]" << std::endl; - } + // d_output_cktile_mem.FromDevice(h_output_cktile.data()); - // Compute reference result - ck_tile::HostTensor h_output_ref({B, output_dim}); - h_output_ref.SetZero(); + // std::cout << "\nCK Tile kernel output (first 2 batches):" << std::endl; + // for(int b = 0; b < std::min(2, B); b++) + // { + // std::cout << "Batch " << b << ": ["; + // for(int i = 0; i < output_dim; i++) + // { + // std::cout << h_output_cktile(b, i); + // if(i < output_dim - 1) + // std::cout << ", "; + // } + // std::cout << "]" << std::endl; + // } - std::cout << "\nComputing reference result..." << std::endl; - ck_tile::reference_mhc( - h_x, h_phi, h_output_ref, n, C, r, alpha_pre, alpha_post, alpha_res, bias); + // // Compute reference result + // ck_tile::HostTensor h_output_ref({B, output_dim}); + // h_output_ref.SetZero(); - std::cout << "\nReference output (first 2 batches):" << std::endl; - for(int b = 0; b < std::min(2, B); b++) - { - std::cout << "Batch " << b << ": ["; - for(int i = 0; i < output_dim; i++) - { - std::cout << h_output_ref(b, i); - if(i < output_dim - 1) - std::cout << ", "; - } - std::cout << "]" << std::endl; - } + // std::cout << "\nComputing reference result..." << std::endl; + // ck_tile::reference_mhc( + // h_x, h_phi, h_output_ref, n, C, r, alpha_pre, alpha_post, alpha_res, bias); - // Validate results - const float rtol = 1e-3f; // Relative tolerance - const float atol = 1e-3f; // Absolute tolerance + // std::cout << "\nReference output (first 2 batches):" << std::endl; + // for(int b = 0; b < std::min(2, B); b++) + // { + // std::cout << "Batch " << b << ": ["; + // for(int i = 0; i < output_dim; i++) + // { + // std::cout << h_output_ref(b, i); + // if(i < output_dim - 1) + // std::cout << ", "; + // } + // std::cout << "]" << std::endl; + // } - bool pass_naive = ck_tile::check_err( - h_output, h_output_ref, "Error: Naive MHC output mismatch!", rtol, atol); + // // Validate results + // const float rtol = 1e-3f; // Relative tolerance + // const float atol = 1e-3f; // Absolute tolerance - bool pass_cktile = ck_tile::check_err( - h_output_cktile, h_output_ref, "Error: CK Tile MHC output mismatch!", rtol, atol); + // bool pass_naive = ck_tile::check_err( + // h_output, h_output_ref, "Error: Naive MHC output mismatch!", rtol, atol); - std::cout << "\n========================================" << std::endl; - std::cout << "Final Results:" << std::endl; - std::cout << " Naive kernel: " << (pass_naive ? "PASS" : "FAIL") << std::endl; - std::cout << " CK Tile kernel: " << (pass_cktile ? "PASS" : "FAIL") << std::endl; - std::cout << "========================================" << std::endl; + // bool pass_cktile = ck_tile::check_err( + // h_output_cktile, h_output_ref, "Error: CK Tile MHC output mismatch!", rtol, atol); - EXPECT_TRUE(pass_naive && pass_cktile); - } + // std::cout << "\n========================================" << std::endl; + // std::cout << "Final Results:" << std::endl; + // std::cout << " Naive kernel: " << (pass_naive ? "PASS" : "FAIL") << std::endl; + // std::cout << " CK Tile kernel: " << (pass_cktile ? "PASS" : "FAIL") << std::endl; + // std::cout << "========================================" << std::endl; - // Test with specific batch size (template version with compile-time parameters) - template - void RunBatchSizeTest() + // EXPECT_TRUE(pass_naive && pass_cktile); + // } + + // // Test with specific batch size (template version with compile-time parameters) + // template + // void RunBatchSizeTest() + // { + // const int nC = n * C; // Total input dimension + // const int output_dim = 2 * n + n * n; // 2n + n^2 + + // std::cout << "\n--- Testing batch size B=" << B << " (n=" << n << ", C=" << C << ") ---" + // << std::endl; + + // // Allocate host tensors + // ck_tile::HostTensor h_x({B, nC}); + // ck_tile::HostTensor h_phi({nC, output_dim}); + // ck_tile::HostTensor h_output({B, output_dim}); + + // // Initialize with random data + // ck_tile::FillUniformDistribution{-1.0f, 1.0f}(h_x); + // ck_tile::FillUniformDistribution{-0.5f, 0.5f}(h_phi); + // h_output.SetZero(); + + // // Allocate device memory + // ck_tile::DeviceMem d_x_mem(h_x.get_element_space_size_in_bytes()); + // ck_tile::DeviceMem d_phi_mem(h_phi.get_element_space_size_in_bytes()); + // ck_tile::DeviceMem d_output_mem(h_output.get_element_space_size_in_bytes()); + + // // Copy data to device + // d_x_mem.ToDevice(h_x.data()); + // d_phi_mem.ToDevice(h_phi.data()); + // d_output_mem.ToDevice(h_output.data()); + + // // Define block shape + // using BlockShape = ck_tile::Generic2dBlockShape, + // ck_tile::sequence<1, 256>, + // ck_tile::sequence<1, 1>>; + + // using Problem = ck_tile::MHCProblem; + + // // Use template parameters for B, n, C (compile-time) + // // This allows better optimization and proper use of store_tile + // using KernelExpansionParallel = ck_tile:: + // ManifoldConstrainedHyperConnectionTiled; + + // const ck_tile::index_t kBlockSize = KernelExpansionParallel::BlockSize(); + // const ck_tile::index_t kGridSize = (output_dim + 15) / 16; + // constexpr ck_tile::index_t kBlockPerCu = 1; + + // // const float r = 1.0f, alpha_pre = 1.0f, alpha_post = 1.0f, alpha_res = 1.0f, bias = + // 0.0f; const float r = 2.0f, alpha_pre = 1.5f, alpha_post = 2.5f, alpha_res = 3.5f, bias + // = 1.5f; + + // // Launch kernel (B, n, C are now template parameters, not runtime) + // ck_tile::launch_kernel( + // ck_tile::stream_config{nullptr, false, 0}, + // ck_tile::make_kernel(KernelExpansionParallel{}, + // kGridSize, + // kBlockSize, + // 0, + // static_cast(d_x_mem.GetDeviceBuffer()), + // static_cast(d_phi_mem.GetDeviceBuffer()), + // static_cast(d_output_mem.GetDeviceBuffer()), + // r, + // alpha_pre, + // alpha_post, + // alpha_res, + // bias)); + + // d_output_mem.FromDevice(h_output.data()); + + // // Compute reference + // ck_tile::HostTensor h_output_ref({B, output_dim}); + // h_output_ref.SetZero(); + // ck_tile::reference_mhc( + // h_x, h_phi, h_output_ref, n, C, r, alpha_pre, alpha_post, alpha_res, bias); + + // // Validate + // bool pass = + // ck_tile::check_err(h_output, h_output_ref, "Error: Batch size mismatch!", 1e-3f, + // 1e-3f); + + // std::cout << " Result: " << (pass ? "PASS" : "FAIL") << std::endl; + + // if(!pass) + // { + // // Print first few values for debugging + // std::cout << " First batch kernel output: ["; + // for(int i = 0; i < std::min(4, output_dim); i++) + // { + // std::cout << h_output(0, i); + // if(i < std::min(4, output_dim) - 1) + // std::cout << ", "; + // } + // std::cout << " ...]" << std::endl; + + // std::cout << " First batch reference: ["; + // for(int i = 0; i < std::min(4, output_dim); i++) + // { + // std::cout << h_output_ref(0, i); + // if(i < std::min(4, output_dim) - 1) + // std::cout << ", "; + // } + // std::cout << " ...]" << std::endl; + // } + + // EXPECT_TRUE(pass); + // } + + // // Test with specific batch size and custom activation function + // template + // void RunBatchSizeTestWithActivation() + // { + // const int nC = n * C; // Total input dimension + // const int output_dim = 2 * n + n * n; // 2n + n^2 + + // std::cout << "\n--- Testing batch size B=" << B << " (n=" << n << ", C=" << C + // << ") with activation: " << ActivationFunc::name << " ---" << std::endl; + + // // Allocate host tensors + // ck_tile::HostTensor h_x({B, nC}); + // ck_tile::HostTensor h_phi({nC, output_dim}); + // ck_tile::HostTensor h_output({B, output_dim}); + + // // Initialize with random data + // ck_tile::FillUniformDistribution{-1.0f, 1.0f}(h_x); + // ck_tile::FillUniformDistribution{-0.5f, 0.5f}(h_phi); + // h_output.SetZero(); + + // // Allocate device memory + // ck_tile::DeviceMem d_x_mem(h_x.get_element_space_size_in_bytes()); + // ck_tile::DeviceMem d_phi_mem(h_phi.get_element_space_size_in_bytes()); + // ck_tile::DeviceMem d_output_mem(h_output.get_element_space_size_in_bytes()); + + // // Copy data to device + // d_x_mem.ToDevice(h_x.data()); + // d_phi_mem.ToDevice(h_phi.data()); + // d_output_mem.ToDevice(h_output.data()); + + // // Define block shape + // using BlockShape = ck_tile::Generic2dBlockShape, + // ck_tile::sequence<1, 256>, + // ck_tile::sequence<1, 1>>; + + // using Problem = ck_tile::MHCProblem; + + // // Use template parameters for B, n, C, and Activation (compile-time) + // using KernelExpansionParallel = + // ck_tile::ManifoldConstrainedHyperConnectionTiled; + + // const ck_tile::index_t kBlockSize = KernelExpansionParallel::BlockSize(); + // const ck_tile::index_t kGridSize = (output_dim + 15) / 16; + // constexpr ck_tile::index_t kBlockPerCu = 1; + + // const float r = 2.0f, alpha_pre = 1.5f, alpha_post = 2.5f, alpha_res = 3.5f, bias = 1.5f; + + // // Launch kernel + // ck_tile::launch_kernel( + // ck_tile::stream_config{nullptr, false, 0}, + // ck_tile::make_kernel(KernelExpansionParallel{}, + // kGridSize, + // kBlockSize, + // 0, + // static_cast(d_x_mem.GetDeviceBuffer()), + // static_cast(d_phi_mem.GetDeviceBuffer()), + // static_cast(d_output_mem.GetDeviceBuffer()), + // r, + // alpha_pre, + // alpha_post, + // alpha_res, + // bias)); + + // d_output_mem.FromDevice(h_output.data()); + + // // Compute reference with the same activation function + // ck_tile::HostTensor h_output_ref({B, output_dim}); + // h_output_ref.SetZero(); + // ck_tile::reference_mhc(h_x, + // h_phi, + // h_output_ref, + // n, + // C, + // r, + // alpha_pre, + // alpha_post, + // alpha_res, + // bias, + // ActivationFunc{}); + + // // Validate + // bool pass = ck_tile::check_err( + // h_output, h_output_ref, "Error: Activation function mismatch!", 1e-3f, 1e-3f); + + // std::cout << " Result: " << (pass ? "PASS" : "FAIL") << std::endl; + + // if(!pass) + // { + // // Print first few values for debugging + // std::cout << " First batch kernel output: ["; + // for(int i = 0; i < std::min(8, output_dim); i++) + // { + // std::cout << h_output(0, i); + // if(i < std::min(8, output_dim) - 1) + // std::cout << ", "; + // } + // std::cout << " ...]" << std::endl; + + // std::cout << " First batch reference: ["; + // for(int i = 0; i < std::min(8, output_dim); i++) + // { + // std::cout << h_output_ref(0, i); + // if(i < std::min(8, output_dim) - 1) + // std::cout << ", "; + // } + // std::cout << " ...]" << std::endl; + // } + + // EXPECT_TRUE(pass); + // } + + // // Test with multiple arbitrary batch sizes + // void RunArbitraryBatchSizeTest() + // { + // std::cout << "\n========================================" << std::endl; + // std::cout << "Testing Arbitrary Batch Sizes..." << std::endl; + // std::cout << " Expansion factor (n): 4" << std::endl; + // std::cout << " Channels per stream (C): 64" << std::endl; + // std::cout << " Output dimension: 24" << std::endl; + // std::cout << "========================================" << std::endl; + + // // Call template versions with compile-time parameters + // RunBatchSizeTest<1, 4, 64>(); + // RunBatchSizeTest<7, 4, 64>(); + // RunBatchSizeTest<15, 4, 64>(); + // RunBatchSizeTest<16, 4, 64>(); + // RunBatchSizeTest<17, 4, 64>(); + // RunBatchSizeTest<23, 4, 64>(); + // RunBatchSizeTest<32, 4, 64>(); + // RunBatchSizeTest<33, 4, 64>(); + // RunBatchSizeTest<47, 4, 64>(); + // RunBatchSizeTest<48, 4, 64>(); + // RunBatchSizeTest<64, 4, 64>(); + + // std::cout << "\n========================================" << std::endl; + // std::cout << "Overall Result: ALL TESTS COMPLETED" << std::endl; + // std::cout << "========================================" << std::endl; + // } + + // // New test: Parallelize by expansion factor (n) instead of batch + // void RunExpansionParallelTest() + // { + // // Test parameters - realistic sizes for BlockGemm + // const int B = 16; // Batch size (M dimension in GEMM) + // const int n = 4; // Expansion rate + // const int C = 64; // Output layer dim (smaller for testing) + // const int nC = n * C; // Total input dimension = 256 (K dimension in + // GEMM) const int output_dim = 2 * n + n * n; // 2n + n^2 = 24 for n=4 + + // std::cout << "\n========================================" << std::endl; + // std::cout << "Testing Expansion-Parallel MHC kernel..." << std::endl; + // std::cout << " Batch size (B): " << B << std::endl; + // std::cout << " Expansion factor (n): " << n << std::endl; + // std::cout << " Channels per stream (C): " << C << std::endl; + // std::cout << " Grid size: " << output_dim << " (one block per expansion stream)" + // << std::endl; + // std::cout << "========================================" << std::endl; + + // // Allocate host tensors + // ck_tile::HostTensor h_x({B, nC}); + // ck_tile::HostTensor h_phi({nC, output_dim}); + // ck_tile::HostTensor h_output({B, output_dim}); + + // // Initialize with random data + // ck_tile::FillUniformDistribution{-1.0f, 1.0f}(h_x); + // ck_tile::FillUniformDistribution{-0.5f, 0.5f}(h_phi); + // h_output.SetZero(); + + // // Allocate device memory + // ck_tile::DeviceMem d_x_mem(h_x.get_element_space_size_in_bytes()); + // ck_tile::DeviceMem d_phi_mem(h_phi.get_element_space_size_in_bytes()); + // ck_tile::DeviceMem d_output_mem(h_output.get_element_space_size_in_bytes()); + + // // Copy data to device + // d_x_mem.ToDevice(h_x.data()); + // d_phi_mem.ToDevice(h_phi.data()); + // d_output_mem.ToDevice(h_output.data()); + + // // Define block shape + // using BlockShape = ck_tile::Generic2dBlockShape, + // ck_tile::sequence<1, 256>, + // ck_tile::sequence<1, 1>>; + + // using Problem = ck_tile::MHCProblem; + // using KernelExpansionParallel = ck_tile:: + // ManifoldConstrainedHyperConnectionTiled; + + // const ck_tile::index_t kBlockSize = KernelExpansionParallel::BlockSize(); + // // Grid size: one block per 16 output columns (since BlockGemm processes N=16) + // const ck_tile::index_t kGridSize = (output_dim + 15) / 16; + // constexpr ck_tile::index_t kBlockPerCu = 1; + + // const float r = 1.0f, alpha_pre = 1.0f, alpha_post = 1.0f, alpha_res = 1.0f, bias = 0.0f; + + // // Launch expansion-parallel kernel (B, n, C are template parameters) + // ck_tile::launch_kernel( + // ck_tile::stream_config{nullptr, false, 0}, + // ck_tile::make_kernel(KernelExpansionParallel{}, + // kGridSize, + // kBlockSize, + // 0, + // static_cast(d_x_mem.GetDeviceBuffer()), + // static_cast(d_phi_mem.GetDeviceBuffer()), + // static_cast(d_output_mem.GetDeviceBuffer()), + // r, + // alpha_pre, + // alpha_post, + // alpha_res, + // bias)); + + // d_output_mem.FromDevice(h_output.data()); + + // // Print kernel output to debug + // std::cout << "\nKernel output (first 2 batches, first 8 elements):" << std::endl; + // for(int b = 0; b < std::min(2, B); b++) + // { + // std::cout << "Batch " << b << ": ["; + // for(int i = 0; i < std::min(8, output_dim); i++) + // { + // std::cout << h_output(b, i); + // if(i < std::min(8, output_dim) - 1) + // std::cout << ", "; + // } + // std::cout << " ...]" << std::endl; + // } + + // // Compute reference + // ck_tile::HostTensor h_output_ref({B, output_dim}); + // h_output_ref.SetZero(); + // ck_tile::reference_mhc( + // h_x, h_phi, h_output_ref, n, C, r, alpha_pre, alpha_post, alpha_res, bias); + + // // Print reference output + // std::cout << "\nReference output (first 2 batches, first 8 elements):" << std::endl; + // for(int b = 0; b < std::min(2, B); b++) + // { + // std::cout << "Batch " << b << ": ["; + // for(int i = 0; i < std::min(8, output_dim); i++) + // { + // std::cout << h_output_ref(b, i); + // if(i < std::min(8, output_dim) - 1) + // std::cout << ", "; + // } + // std::cout << " ...]" << std::endl; + // } + + // // Validate + // bool pass = ck_tile::check_err( + // h_output, h_output_ref, "Error: Expansion-parallel MHC mismatch!", 1e-3f, 1e-3f); + + // std::cout << "Expansion-parallel kernel: " << (pass ? "PASS" : "FAIL") << std::endl; + // EXPECT_TRUE(pass); + // } + + template + void RunGemmPipeline() { const int nC = n * C; // Total input dimension - const int output_dim = 2 * n + n * n; // 2n + n^2 + const int output_dim = 2 * n + n * n; // 2n + n^2 = 24 - std::cout << "\n--- Testing batch size B=" << B << " (n=" << n << ", C=" << C << ") ---" - << std::endl; + // using ActivationFunc = ck_tile::element_wise::Sigmoid; + + std::cout << "\n--- Testing MHC Kernel V3 with B=" << B << " (n=" << n << ", C=" << C + << ") ---" << std::endl; + std::cout << "Output dimension: " << output_dim << std::endl; // Allocate host tensors ck_tile::HostTensor h_x({B, nC}); @@ -303,35 +679,53 @@ class TestCkTileMHC : public ::testing::Test d_phi_mem.ToDevice(h_phi.data()); d_output_mem.ToDevice(h_output.data()); - // Define block shape - using BlockShape = ck_tile::Generic2dBlockShape, - ck_tile::sequence<1, 256>, + // Define block shape - must match BlockGemmShape thread count (2 warps × 64 = 128 threads) + using BlockShape = ck_tile::Generic2dBlockShape, + ck_tile::sequence<1, 128>, ck_tile::sequence<1, 1>>; using Problem = ck_tile::MHCProblem; - // Use template parameters for B, n, C (compile-time) - // This allows better optimization and proper use of store_tile - using KernelExpansionParallel = ck_tile:: - ManifoldConstrainedHyperConnectionTiled; + // V3 kernel with 2D tiling + constexpr ck_tile::index_t kMTile = 64; // Batch tile + constexpr ck_tile::index_t kNTile = 32; // Output tile (exactly covers 24 outputs for n=4) + constexpr ck_tile::index_t kKTile = + 8; // K tile for C dimension (must match BlockGemmShape::kK) + + using KernelV3 = ck_tile:: + MHCKernelV3; + + const ck_tile::index_t kBlockSize = KernelV3::BlockSize(); + + // 2D grid: (batch / kMTile) × (output_dim / kNTile) + auto grid_size = KernelV3::GetGridSize(B, output_dim); + const ck_tile::index_t kGridSize = + grid_size.at(ck_tile::number<0>{}) * grid_size.at(ck_tile::number<1>{}); + + std::cout << "Grid configuration: " << grid_size.at(ck_tile::number<0>{}) << " × " + << grid_size.at(ck_tile::number<1>{}) << " = " << kGridSize << " blocks" + << std::endl; + std::cout << "Block size: " << kBlockSize << " threads" << std::endl; + std::cout << "Shared memory: " << KernelV3::GetSmemSize() << " bytes" << std::endl; - const ck_tile::index_t kBlockSize = KernelExpansionParallel::BlockSize(); - const ck_tile::index_t kGridSize = (output_dim + 15) / 16; constexpr ck_tile::index_t kBlockPerCu = 1; - // const float r = 1.0f, alpha_pre = 1.0f, alpha_post = 1.0f, alpha_res = 1.0f, bias = 0.0f; const float r = 2.0f, alpha_pre = 1.5f, alpha_post = 2.5f, alpha_res = 3.5f, bias = 1.5f; - // Launch kernel (B, n, C are now template parameters, not runtime) + // Launch kernel ck_tile::launch_kernel( ck_tile::stream_config{nullptr, false, 0}, - ck_tile::make_kernel(KernelExpansionParallel{}, + ck_tile::make_kernel(KernelV3{}, kGridSize, kBlockSize, - 0, + KernelV3::GetSmemSize(), static_cast(d_x_mem.GetDeviceBuffer()), static_cast(d_phi_mem.GetDeviceBuffer()), static_cast(d_output_mem.GetDeviceBuffer()), + B, + nC, + output_dim, + n, r, alpha_pre, alpha_post, @@ -343,117 +737,6 @@ class TestCkTileMHC : public ::testing::Test // Compute reference ck_tile::HostTensor h_output_ref({B, output_dim}); h_output_ref.SetZero(); - ck_tile::reference_mhc( - h_x, h_phi, h_output_ref, n, C, r, alpha_pre, alpha_post, alpha_res, bias); - - // Validate - bool pass = - ck_tile::check_err(h_output, h_output_ref, "Error: Batch size mismatch!", 1e-3f, 1e-3f); - - std::cout << " Result: " << (pass ? "PASS" : "FAIL") << std::endl; - - if(!pass) - { - // Print first few values for debugging - std::cout << " First batch kernel output: ["; - for(int i = 0; i < std::min(4, output_dim); i++) - { - std::cout << h_output(0, i); - if(i < std::min(4, output_dim) - 1) - std::cout << ", "; - } - std::cout << " ...]" << std::endl; - - std::cout << " First batch reference: ["; - for(int i = 0; i < std::min(4, output_dim); i++) - { - std::cout << h_output_ref(0, i); - if(i < std::min(4, output_dim) - 1) - std::cout << ", "; - } - std::cout << " ...]" << std::endl; - } - - EXPECT_TRUE(pass); - } - - // Test with specific batch size and custom activation function - template - void RunBatchSizeTestWithActivation() - { - const int nC = n * C; // Total input dimension - const int output_dim = 2 * n + n * n; // 2n + n^2 - - std::cout << "\n--- Testing batch size B=" << B << " (n=" << n << ", C=" << C - << ") with activation: " << ActivationFunc::name << " ---" << std::endl; - - // Allocate host tensors - ck_tile::HostTensor h_x({B, nC}); - ck_tile::HostTensor h_phi({nC, output_dim}); - ck_tile::HostTensor h_output({B, output_dim}); - - // Initialize with random data - ck_tile::FillUniformDistribution{-1.0f, 1.0f}(h_x); - ck_tile::FillUniformDistribution{-0.5f, 0.5f}(h_phi); - h_output.SetZero(); - - // Allocate device memory - ck_tile::DeviceMem d_x_mem(h_x.get_element_space_size_in_bytes()); - ck_tile::DeviceMem d_phi_mem(h_phi.get_element_space_size_in_bytes()); - ck_tile::DeviceMem d_output_mem(h_output.get_element_space_size_in_bytes()); - - // Copy data to device - d_x_mem.ToDevice(h_x.data()); - d_phi_mem.ToDevice(h_phi.data()); - d_output_mem.ToDevice(h_output.data()); - - // Define block shape - using BlockShape = ck_tile::Generic2dBlockShape, - ck_tile::sequence<1, 256>, - ck_tile::sequence<1, 1>>; - - using Problem = ck_tile::MHCProblem; - - // Use template parameters for B, n, C, and Activation (compile-time) - using KernelExpansionParallel = - ck_tile::ManifoldConstrainedHyperConnectionTiled; - - const ck_tile::index_t kBlockSize = KernelExpansionParallel::BlockSize(); - const ck_tile::index_t kGridSize = (output_dim + 15) / 16; - constexpr ck_tile::index_t kBlockPerCu = 1; - - const float r = 2.0f, alpha_pre = 1.5f, alpha_post = 2.5f, alpha_res = 3.5f, bias = 1.5f; - - // Launch kernel - ck_tile::launch_kernel( - ck_tile::stream_config{nullptr, false, 0}, - ck_tile::make_kernel(KernelExpansionParallel{}, - kGridSize, - kBlockSize, - 0, - static_cast(d_x_mem.GetDeviceBuffer()), - static_cast(d_phi_mem.GetDeviceBuffer()), - static_cast(d_output_mem.GetDeviceBuffer()), - r, - alpha_pre, - alpha_post, - alpha_res, - bias)); - - d_output_mem.FromDevice(h_output.data()); - - // Compute reference with the same activation function - ck_tile::HostTensor h_output_ref({B, output_dim}); - h_output_ref.SetZero(); ck_tile::reference_mhc(h_x, h_phi, h_output_ref, @@ -468,175 +751,9 @@ class TestCkTileMHC : public ::testing::Test // Validate bool pass = ck_tile::check_err( - h_output, h_output_ref, "Error: Activation function mismatch!", 1e-3f, 1e-3f); + h_output, h_output_ref, "Error: MHC V3 kernel output mismatch!", 1e-3f, 1e-3f); - std::cout << " Result: " << (pass ? "PASS" : "FAIL") << std::endl; - - if(!pass) - { - // Print first few values for debugging - std::cout << " First batch kernel output: ["; - for(int i = 0; i < std::min(8, output_dim); i++) - { - std::cout << h_output(0, i); - if(i < std::min(8, output_dim) - 1) - std::cout << ", "; - } - std::cout << " ...]" << std::endl; - - std::cout << " First batch reference: ["; - for(int i = 0; i < std::min(8, output_dim); i++) - { - std::cout << h_output_ref(0, i); - if(i < std::min(8, output_dim) - 1) - std::cout << ", "; - } - std::cout << " ...]" << std::endl; - } - - EXPECT_TRUE(pass); - } - - // Test with multiple arbitrary batch sizes - void RunArbitraryBatchSizeTest() - { - std::cout << "\n========================================" << std::endl; - std::cout << "Testing Arbitrary Batch Sizes..." << std::endl; - std::cout << " Expansion factor (n): 4" << std::endl; - std::cout << " Channels per stream (C): 64" << std::endl; - std::cout << " Output dimension: 24" << std::endl; - std::cout << "========================================" << std::endl; - - // Call template versions with compile-time parameters - RunBatchSizeTest<1, 4, 64>(); - RunBatchSizeTest<7, 4, 64>(); - RunBatchSizeTest<15, 4, 64>(); - RunBatchSizeTest<16, 4, 64>(); - RunBatchSizeTest<17, 4, 64>(); - RunBatchSizeTest<23, 4, 64>(); - RunBatchSizeTest<32, 4, 64>(); - RunBatchSizeTest<33, 4, 64>(); - RunBatchSizeTest<47, 4, 64>(); - RunBatchSizeTest<48, 4, 64>(); - RunBatchSizeTest<64, 4, 64>(); - - std::cout << "\n========================================" << std::endl; - std::cout << "Overall Result: ALL TESTS COMPLETED" << std::endl; - std::cout << "========================================" << std::endl; - } - - // New test: Parallelize by expansion factor (n) instead of batch - void RunExpansionParallelTest() - { - // Test parameters - realistic sizes for BlockGemm - const int B = 16; // Batch size (M dimension in GEMM) - const int n = 4; // Expansion rate - const int C = 64; // Output layer dim (smaller for testing) - const int nC = n * C; // Total input dimension = 256 (K dimension in GEMM) - const int output_dim = 2 * n + n * n; // 2n + n^2 = 24 for n=4 - - std::cout << "\n========================================" << std::endl; - std::cout << "Testing Expansion-Parallel MHC kernel..." << std::endl; - std::cout << " Batch size (B): " << B << std::endl; - std::cout << " Expansion factor (n): " << n << std::endl; - std::cout << " Channels per stream (C): " << C << std::endl; - std::cout << " Grid size: " << output_dim << " (one block per expansion stream)" - << std::endl; - std::cout << "========================================" << std::endl; - - // Allocate host tensors - ck_tile::HostTensor h_x({B, nC}); - ck_tile::HostTensor h_phi({nC, output_dim}); - ck_tile::HostTensor h_output({B, output_dim}); - - // Initialize with random data - ck_tile::FillUniformDistribution{-1.0f, 1.0f}(h_x); - ck_tile::FillUniformDistribution{-0.5f, 0.5f}(h_phi); - h_output.SetZero(); - - // Allocate device memory - ck_tile::DeviceMem d_x_mem(h_x.get_element_space_size_in_bytes()); - ck_tile::DeviceMem d_phi_mem(h_phi.get_element_space_size_in_bytes()); - ck_tile::DeviceMem d_output_mem(h_output.get_element_space_size_in_bytes()); - - // Copy data to device - d_x_mem.ToDevice(h_x.data()); - d_phi_mem.ToDevice(h_phi.data()); - d_output_mem.ToDevice(h_output.data()); - - // Define block shape - using BlockShape = ck_tile::Generic2dBlockShape, - ck_tile::sequence<1, 256>, - ck_tile::sequence<1, 1>>; - - using Problem = ck_tile::MHCProblem; - using KernelExpansionParallel = ck_tile:: - ManifoldConstrainedHyperConnectionTiled; - - const ck_tile::index_t kBlockSize = KernelExpansionParallel::BlockSize(); - // Grid size: one block per 16 output columns (since BlockGemm processes N=16) - const ck_tile::index_t kGridSize = (output_dim + 15) / 16; - constexpr ck_tile::index_t kBlockPerCu = 1; - - const float r = 1.0f, alpha_pre = 1.0f, alpha_post = 1.0f, alpha_res = 1.0f, bias = 0.0f; - - // Launch expansion-parallel kernel (B, n, C are template parameters) - ck_tile::launch_kernel( - ck_tile::stream_config{nullptr, false, 0}, - ck_tile::make_kernel(KernelExpansionParallel{}, - kGridSize, - kBlockSize, - 0, - static_cast(d_x_mem.GetDeviceBuffer()), - static_cast(d_phi_mem.GetDeviceBuffer()), - static_cast(d_output_mem.GetDeviceBuffer()), - r, - alpha_pre, - alpha_post, - alpha_res, - bias)); - - d_output_mem.FromDevice(h_output.data()); - - // Print kernel output to debug - std::cout << "\nKernel output (first 2 batches, first 8 elements):" << std::endl; - for(int b = 0; b < std::min(2, B); b++) - { - std::cout << "Batch " << b << ": ["; - for(int i = 0; i < std::min(8, output_dim); i++) - { - std::cout << h_output(b, i); - if(i < std::min(8, output_dim) - 1) - std::cout << ", "; - } - std::cout << " ...]" << std::endl; - } - - // Compute reference - ck_tile::HostTensor h_output_ref({B, output_dim}); - h_output_ref.SetZero(); - ck_tile::reference_mhc( - h_x, h_phi, h_output_ref, n, C, r, alpha_pre, alpha_post, alpha_res, bias); - - // Print reference output - std::cout << "\nReference output (first 2 batches, first 8 elements):" << std::endl; - for(int b = 0; b < std::min(2, B); b++) - { - std::cout << "Batch " << b << ": ["; - for(int i = 0; i < std::min(8, output_dim); i++) - { - std::cout << h_output_ref(b, i); - if(i < std::min(8, output_dim) - 1) - std::cout << ", "; - } - std::cout << " ...]" << std::endl; - } - - // Validate - bool pass = ck_tile::check_err( - h_output, h_output_ref, "Error: Expansion-parallel MHC mismatch!", 1e-3f, 1e-3f); - - std::cout << "Expansion-parallel kernel: " << (pass ? "PASS" : "FAIL") << std::endl; + std::cout << "Result: " << (pass ? "PASS" : "FAIL") << std::endl; EXPECT_TRUE(pass); }