From b161cd94ccdaebadaab0fd5dc68bb0a22e091dff Mon Sep 17 00:00:00 2001
From: Mohsen Saffari
Date: Thu, 16 Oct 2025 09:24:39 +0000
Subject: [PATCH] Refactor the batched contraction compute reference to use
 stride-aware calculation and clean up some code

---
 .../contraction_utils.hpp                     |  28 +-
 .../run_batched_contraction_example.inc       |  32 +-
 .../reference_batched_contraction.hpp         | 393 ++++++++----------
 3 files changed, 213 insertions(+), 240 deletions(-)

diff --git a/example/ck_tile/41_batched_contraction/contraction_utils.hpp b/example/ck_tile/41_batched_contraction/contraction_utils.hpp
index d50d836727..9b7d43472e 100644
--- a/example/ck_tile/41_batched_contraction/contraction_utils.hpp
+++ b/example/ck_tile/41_batched_contraction/contraction_utils.hpp
@@ -48,30 +48,34 @@ void print_help(const char* program_name)
     std::cout << "Batched Tensor Contraction with element-wise fusion\n";
     std::cout << "E[G,M,N] = element_wise_op(contraction(A[G,M,K], B[G,N,K]), D0, D1, ...)\n";
     std::cout << "(Supports multiple D tensors with configurable element-wise operations)\n\n";
-
+
     std::cout << "Usage: " << program_name << " [OPTIONS]\n\n";
-
+
     std::cout << "Dimension Arguments (comma-separated, no spaces):\n";
     std::cout << "  -g_dims=<dims>   Batch dimensions (default: \"1,2\")\n";
     std::cout << "  -m_dims=<dims>   M (row) dimensions (default: \"4,256\")\n";
     std::cout << "  -n_dims=<dims>   N (column) dimensions (default: \"16,128\")\n";
     std::cout << "  -k_dims=<dims>   K (contract) dims (default: \"64\")\n\n";
-
+
     std::cout << "Layout Arguments:\n";
-    std::cout << "  -a_layout=<layout>  A tensor layout (R=Row-major, C=Column-major, default: \"R\")\n";
+    std::cout
+        << "  -a_layout=<layout>  A tensor layout (R=Row-major, C=Column-major, default: \"R\")\n";
     std::cout << "  -b_layout=<layout>  B tensor layout (default: \"C\")\n";
     std::cout << "  -e_layout=<layout>  E tensor layout (default: \"R\")\n\n";
-
+
     std::cout << "Examples:\n";
     std::cout << "  Single batch (12 batches of 256×128):\n";
-    std::cout << "    " << program_name << " -g_dims=\"12\" -m_dims=\"256\" -n_dims=\"128\" -k_dims=\"64\"\n\n";
-
+    std::cout << "    " << program_name
+              << " -g_dims=\"12\" -m_dims=\"256\" -n_dims=\"128\" -k_dims=\"64\"\n\n";
+
     std::cout << "  2D batch grid (2×3=6 batches):\n";
-    std::cout << "    " << program_name << " -g_dims=\"2,3\" -m_dims=\"128\" -n_dims=\"128\" -k_dims=\"64\"\n\n";
-
+    std::cout << "    " << program_name
+              << " -g_dims=\"2,3\" -m_dims=\"128\" -n_dims=\"128\" -k_dims=\"64\"\n\n";
+
     std::cout << "  Multi-dimensional (flattened to M=128, N=128, K=128):\n";
-    std::cout << "    " << program_name << " -g_dims=\"4\" -m_dims=\"8,16\" -n_dims=\"32,4\" -k_dims=\"16,8\"\n\n";
-
+    std::cout << "    " << program_name
+              << " -g_dims=\"4\" -m_dims=\"8,16\" -n_dims=\"32,4\" -k_dims=\"16,8\"\n\n";
+
     std::cout << "Other Options:\n";
     std::cout << "  -v=<0|1>         Validation (0=off, 1=on, default: 1)\n";
     std::cout << "  -split_k=<value> Split-K value (default: 1)\n";
@@ -93,7 +97,7 @@ auto create_args(int argc, char* argv[])
             std::exit(0);
         }
     }
-
+
    ck_tile::ArgParser arg_parser;
    arg_parser.insert("m_dims", "4,256", "M dimensions separated by comma (e.g., '16,32' for 2D M)")
        .insert("n_dims", "16,128", "N dimensions separated by comma (e.g., '32,32' for 2D N)")
diff --git a/example/ck_tile/41_batched_contraction/run_batched_contraction_example.inc b/example/ck_tile/41_batched_contraction/run_batched_contraction_example.inc
index 9bc09a6c9c..aa3b91c6a0 100644
--- a/example/ck_tile/41_batched_contraction/run_batched_contraction_example.inc
+++ b/example/ck_tile/41_batched_contraction/run_batched_contraction_example.inc
@@ -201,11 +201,14 @@ int run_batched_contraction_example_with_layouts(
     ck_tile::HostTensor<::BDataType> b_full_dims_host(b_desc);
     ck_tile::HostTensor<::EDataType> e_full_dims_host(e_desc);
 
-    std::vector<ck_tile::HostTensor<::DDataType>> ds_full_dims_host;
-    for(int d = 0; d < NumDTensor; ++d)
-    {
-        ds_full_dims_host.emplace_back(ck_tile::HostTensor<::DDataType>(ds_descs[d]));
-    }
+    // Helper to construct array of HostTensors using index_sequence
+    auto make_ds_host_tensors = []<std::size_t... Is>(const auto& descs,
+                                                      std::index_sequence<Is...>) {
+        return std::array<ck_tile::HostTensor<::DDataType>, sizeof...(Is)>{
+            ck_tile::HostTensor<::DDataType>(descs[Is])...};
+    };
+
+    auto ds_full_dims_host = make_ds_host_tensors(ds_descs, std::make_index_sequence<NumDTensor>{});
 
     ck_tile::FillUniformDistribution<::ADataType>{-5.f, 5.f, std::nullopt}(a_full_dims_host);
     ck_tile::FillUniformDistribution<::BDataType>{-5.f, 5.f, std::nullopt}(b_full_dims_host);
@@ -316,12 +319,13 @@ int run_batched_contraction_example_with_layouts(
 
     auto start_time = std::chrono::high_resolution_clock::now();
 
-    calculate_reference_flat_indexing(a_full_dims_host,
+    compute_reference_batched_contraction(a_full_dims_host,
                                           b_full_dims_host,
                                           ds_full_dims_host,
                                           e_full_dims_host_ref,
                                           G_total,
@@ -329,7 +333,11 @@ int run_batched_contraction_example_with_layouts(
                                           M_total,
                                           N_total,
                                           K_total,
-                                          CDEElementWise{});
+                                          CDEElementWise{},
+                                          G_dims,
+                                          M_dims,
+                                          N_dims,
+                                          K_dims);
 
     auto end_time = std::chrono::high_resolution_clock::now();
     auto duration =
diff --git a/include/ck_tile/host/reference/reference_batched_contraction.hpp b/include/ck_tile/host/reference/reference_batched_contraction.hpp
index 1ce071969c..324ff06ef7 100644
--- a/include/ck_tile/host/reference/reference_batched_contraction.hpp
+++ b/include/ck_tile/host/reference/reference_batched_contraction.hpp
@@ -11,110 +11,210 @@
 namespace ck_tile {
 
+// Helper to apply elementwise operation with variable number of D tensors
+template <typename EDataType, typename AccDataType, typename CDEElementWise>
+struct ApplyCDEElementWise
+{
+    template <typename... DValues>
+    CK_TILE_HOST_DEVICE static void apply(EDataType& result,
+                                          AccDataType sum,
+                                          const CDEElementWise& cde_elementwise,
+                                          DValues... d_vals)
+    {
+        if constexpr(sizeof...(DValues) == 0)
+        {
+            result = static_cast<EDataType>(sum);
+        }
+        else
+        {
+            cde_elementwise(
+                result,
+                ck_tile::type_convert<AccDataType>(sum),
+                ck_tile::type_convert<AccDataType>(d_vals)...);
+        }
+    }
+};
+
+// Helper to extract D values at a given offset using index sequence
+template <ck_tile::index_t NumDTensor,
+          typename Indices = std::make_index_sequence<NumDTensor>>
+struct ExtractDValues;
+
+template <ck_tile::index_t NumDTensor, std::size_t... Is>
+struct ExtractDValues<NumDTensor, std::index_sequence<Is...>>
+{
+    template <typename EDataType,
+              typename AccDataType,
+              typename CDEElementWise,
+              typename DDataType>
+    CK_TILE_HOST static void
+    apply_at_offset(EDataType& result,
+                    AccDataType sum,
+                    const CDEElementWise& cde_elementwise,
+                    const std::array<ck_tile::HostTensor<DDataType>, NumDTensor>& ds_tensors,
+                    std::size_t offset)
+    {
+        ApplyCDEElementWise<EDataType, AccDataType, CDEElementWise>::apply(
+            result, sum, cde_elementwise, ds_tensors[Is].mData[offset]...);
+    }
+};
+
 template <typename ADataType,
           typename BDataType,
           typename DDataType,
           typename AccDataType,
           typename EDataType,
-          typename CDEElementWise>
-void calculate_reference_flat_indexing(
+          typename CDEElementWise,
+          ck_tile::index_t NumDTensor>
+void compute_reference_batched_contraction(
     const ck_tile::HostTensor<ADataType>& a_full_dims,
     const ck_tile::HostTensor<BDataType>& b_full_dims,
-    const std::vector<ck_tile::HostTensor<DDataType>>& ds_full_dims_host,
+    const std::array<ck_tile::HostTensor<DDataType>, NumDTensor>& ds_full_dims_host,
     ck_tile::HostTensor<EDataType>& e_full_dims_host_ref,
     ck_tile::index_t G_total,
     ck_tile::index_t M_total,
     ck_tile::index_t N_total,
     ck_tile::index_t K_total,
-    const CDEElementWise& cde_elementwise)
+    const CDEElementWise& cde_elementwise,
+    const std::vector<ck_tile::index_t>& G_dims,
+    const std::vector<ck_tile::index_t>& M_dims,
+    const std::vector<ck_tile::index_t>& N_dims,
+    const std::vector<ck_tile::index_t>& K_dims)
 {
-    std::cout << "Calculating reference using optimized flat indexing with parallel processing..."
+    std::cout << "Calculating reference using stride-aware indexing with parallel processing..."
               << std::endl;
 
-    // Parallel computation over G and M dimensions using pattern from reference_batched_gemm.hpp
+    // Extract stride information from tensor descriptors
+    const auto a_strides = a_full_dims.get_strides();
+    const auto b_strides = b_full_dims.get_strides();
+    const auto e_strides = e_full_dims_host_ref.get_strides();
+
+    const ck_tile::index_t num_g_dims = G_dims.size();
+    const ck_tile::index_t num_m_dims = M_dims.size();
+    const ck_tile::index_t num_n_dims = N_dims.size();
+    const ck_tile::index_t num_k_dims = K_dims.size();
+
+    // Helper lambda to compute linear index from flat indices using strides
+    auto compute_a_offset = [&](ck_tile::index_t g_flat,
+                                ck_tile::index_t m_flat,
+                                ck_tile::index_t k_flat) -> std::size_t {
+        std::size_t offset = 0;
+
+        // Decode G dimensions
+        ck_tile::index_t temp = g_flat;
+        for(ck_tile::index_t i = num_g_dims - 1; i >= 0; --i)
+        {
+            offset += (temp % G_dims[i]) * a_strides[i];
+            temp /= G_dims[i];
+        }
+
+        // Decode M dimensions
+        temp = m_flat;
+        for(ck_tile::index_t i = num_m_dims - 1; i >= 0; --i)
+        {
+            offset += (temp % M_dims[i]) * a_strides[num_g_dims + i];
+            temp /= M_dims[i];
+        }
+
+        // Decode K dimensions
+        temp = k_flat;
+        for(ck_tile::index_t i = num_k_dims - 1; i >= 0; --i)
+        {
+            offset += (temp % K_dims[i]) * a_strides[num_g_dims + num_m_dims + i];
+            temp /= K_dims[i];
+        }
+
+        return offset;
+    };
+
+    auto compute_b_offset = [&](ck_tile::index_t g_flat,
+                                ck_tile::index_t n_flat,
+                                ck_tile::index_t k_flat) -> std::size_t {
+        std::size_t offset = 0;
+
+        // Decode G dimensions
+        ck_tile::index_t temp = g_flat;
+        for(ck_tile::index_t i = num_g_dims - 1; i >= 0; --i)
+        {
+            offset += (temp % G_dims[i]) * b_strides[i];
+            temp /= G_dims[i];
+        }
+
+        // Decode N dimensions
+        temp = n_flat;
+        for(ck_tile::index_t i = num_n_dims - 1; i >= 0; --i)
+        {
+            offset += (temp % N_dims[i]) * b_strides[num_g_dims + i];
+            temp /= N_dims[i];
+        }
+
+        // Decode K dimensions
+        temp = k_flat;
+        for(ck_tile::index_t i = num_k_dims - 1; i >= 0; --i)
+        {
+            offset += (temp % K_dims[i]) * b_strides[num_g_dims + num_n_dims + i];
+            temp /= K_dims[i];
+        }
+
+        return offset;
+    };
+
+    auto compute_e_offset = [&](ck_tile::index_t g_flat,
+                                ck_tile::index_t m_flat,
+                                ck_tile::index_t n_flat) -> std::size_t {
+        std::size_t offset = 0;
+
+        // Decode G dimensions
+        ck_tile::index_t temp = g_flat;
+        for(ck_tile::index_t i = num_g_dims - 1; i >= 0; --i)
+        {
+            offset += (temp % G_dims[i]) * e_strides[i];
+            temp /= G_dims[i];
+        }
+
+        // Decode M dimensions
+        temp = m_flat;
+        for(ck_tile::index_t i = num_m_dims - 1; i >= 0; --i)
+        {
+            offset += (temp % M_dims[i]) * e_strides[num_g_dims + i];
+            temp /= M_dims[i];
+        }
+
+        // Decode N dimensions
+        temp = n_flat;
+        for(ck_tile::index_t i = num_n_dims - 1; i >= 0; --i)
+        {
+            offset += (temp % N_dims[i]) * e_strides[num_g_dims + num_m_dims + i];
+            temp /= N_dims[i];
+        }
+
+        return offset;
+    };
+
+    // Parallel computation over G and M dimensions
     auto f_gm = [&](auto g_flat, auto m_flat) {
         for(ck_tile::index_t n_flat = 0; n_flat < N_total; ++n_flat)
         {
             AccDataType sum = 0;
 
-            // Compute dot product over K dimension
+            // Compute dot product over K dimension using stride-aware indexing
             for(ck_tile::index_t k_flat = 0; k_flat < K_total; ++k_flat)
             {
-                auto a_val =
-                    a_full_dims.mData[g_flat * M_total * K_total + m_flat * K_total + k_flat];
-                auto b_val =
-                    b_full_dims.mData[g_flat * N_total * K_total + n_flat * K_total + k_flat];
+                const std::size_t a_offset = compute_a_offset(g_flat, m_flat, k_flat);
+                const std::size_t b_offset = compute_b_offset(g_flat, n_flat, k_flat);
+
+                auto a_val = a_full_dims.mData[a_offset];
+                auto b_val = b_full_dims.mData[b_offset];
 
                 sum += static_cast<AccDataType>(a_val) * static_cast<AccDataType>(b_val);
             }
 
-            // Apply elementwise operation with D tensors
-            EDataType result = static_cast<EDataType>(sum);
-            if(ds_full_dims_host.size() == 0)
-            {
-                ;
-            }
-            else if(ds_full_dims_host.size() == 1)
-            {
-                cde_elementwise(result,
-                                ck_tile::type_convert<AccDataType>(sum),
-                                ck_tile::type_convert<AccDataType>(
-                                    ds_full_dims_host[0].mData[g_flat * M_total * N_total +
-                                                               m_flat * N_total + n_flat]));
-            }
-            else if(ds_full_dims_host.size() == 2)
-            {
-                cde_elementwise(
-                    result,
-                    ck_tile::type_convert<AccDataType>(sum),
-                    ck_tile::type_convert<AccDataType>(
-                        ds_full_dims_host[0]
-                            .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
-                    ck_tile::type_convert<AccDataType>(
-                        ds_full_dims_host[1]
-                            .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
-            }
-            else if(ds_full_dims_host.size() == 3)
-            {
-                cde_elementwise(
-                    result,
-                    ck_tile::type_convert<AccDataType>(sum),
-                    ck_tile::type_convert<AccDataType>(
-                        ds_full_dims_host[0]
-                            .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
-                    ck_tile::type_convert<AccDataType>(
-                        ds_full_dims_host[1]
-                            .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
-                    ck_tile::type_convert<AccDataType>(
-                        ds_full_dims_host[2]
-                            .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
-            }
-            else if(ds_full_dims_host.size() == 4)
-            {
-                cde_elementwise(
-                    result,
-                    ck_tile::type_convert<AccDataType>(sum),
-                    ck_tile::type_convert<AccDataType>(
-                        ds_full_dims_host[0]
-                            .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
-                    ck_tile::type_convert<AccDataType>(
-                        ds_full_dims_host[1]
-                            .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
-                    ck_tile::type_convert<AccDataType>(
-                        ds_full_dims_host[2]
-                            .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]),
-                    ck_tile::type_convert<AccDataType>(
-                        ds_full_dims_host[3]
-                            .mData[g_flat * M_total * N_total + m_flat * N_total + n_flat]));
-            }
-            else
-            {
-                throw std::runtime_error("Unsupported NumDTensor for reference calculation");
-            }
+            // Compute output offset using strides
+            const std::size_t e_offset = compute_e_offset(g_flat, m_flat, n_flat);
 
-            // Store result
-            e_full_dims_host_ref.mData[g_flat * M_total * N_total + m_flat * N_total + n_flat] =
-                static_cast<EDataType>(result);
+            // Apply elementwise operation with D tensors using compile-time dispatch
+            EDataType result = static_cast<EDataType>(sum);
+            ExtractDValues<NumDTensor>::apply_at_offset(
+                result, sum, cde_elementwise, ds_full_dims_host, e_offset);
+
+            // Store result using stride-aware indexing
+            e_full_dims_host_ref.mData[e_offset] = static_cast<EDataType>(result);
         }
     };
 
@@ -123,143 +223,4 @@ void calculate_reference_flat_indexing(
     make_ParallelTensorFunctor(f_gm, G_total, M_total)(std::thread::hardware_concurrency());
 }
 
-template <typename ADataType,
-          typename BDataType,
-          typename DDataType,
-          typename AccDataType,
-          typename EDataType,
-          typename CDEElementWise>
-void calculate_reference_multi_dimensional(
-    const HostTensor<ADataType>& a_full_dims,
-    const HostTensor<BDataType>& b_full_dims,
-    const std::vector<HostTensor<DDataType>>& ds_full_dims_host,
-    HostTensor<EDataType>& e_full_dims_host_ref,
-    const std::vector<ck_tile::index_t>& G_dims,
-    const std::vector<ck_tile::index_t>& M_dims,
-    const std::vector<ck_tile::index_t>& N_dims,
-    const std::vector<ck_tile::index_t>& K_dims,
-    const std::vector<ck_tile::index_t>& A_dims,
-    const std::vector<ck_tile::index_t>& B_dims,
-    const std::vector<ck_tile::index_t>& E_dims,
-    const CDEElementWise& cde_elementwise)
-{
-    std::cout << "Calculating reference using multi-dimensional indexing..."
-              << std::endl;
-
-    std::vector<ck_tile::index_t> g_idx(G_dims.size());
-    std::vector<ck_tile::index_t> m_idx(M_dims.size());
-    std::vector<ck_tile::index_t> n_idx(N_dims.size());
-    std::vector<ck_tile::index_t> k_idx(K_dims.size());
-    std::vector<ck_tile::index_t> a_idx, b_idx, e_idx;
-
-    a_idx.reserve(A_dims.size());
-    b_idx.reserve(B_dims.size());
-    e_idx.reserve(E_dims.size());
-
-    for(ck_tile::index_t g_flat = 0; g_flat < calculate_total_elements(G_dims); ++g_flat)
-    {
-        ck_tile::index_t temp = g_flat;
-        for(int i = G_dims.size() - 1; i >= 0; --i)
-        {
-            g_idx[i] = temp % G_dims[i];
-            temp /= G_dims[i];
-        }
-
-        for(ck_tile::index_t m_flat = 0; m_flat < calculate_total_elements(M_dims); ++m_flat)
-        {
-            temp = m_flat;
-            for(int i = M_dims.size() - 1; i >= 0; --i)
-            {
-                m_idx[i] = temp % M_dims[i];
-                temp /= M_dims[i];
-            }
-
-            for(ck_tile::index_t n_flat = 0; n_flat < calculate_total_elements(N_dims); ++n_flat)
-            {
-                temp = n_flat;
-                for(int i = N_dims.size() - 1; i >= 0; --i)
-                {
-                    n_idx[i] = temp % N_dims[i];
-                    temp /= N_dims[i];
-                }
-
-                AccDataType sum = 0;
-
-                for(ck_tile::index_t k_flat = 0; k_flat < calculate_total_elements(K_dims);
-                    ++k_flat)
-                {
-                    temp = k_flat;
-                    for(int i = K_dims.size() - 1; i >= 0; --i)
-                    {
-                        k_idx[i] = temp % K_dims[i];
-                        temp /= K_dims[i];
-                    }
-
-                    a_idx.clear();
-                    b_idx.clear();
-
-                    a_idx.insert(a_idx.end(), g_idx.begin(), g_idx.end());
-                    a_idx.insert(a_idx.end(), m_idx.begin(), m_idx.end());
-                    a_idx.insert(a_idx.end(), k_idx.begin(), k_idx.end());
-
-                    b_idx.insert(b_idx.end(), g_idx.begin(), g_idx.end());
-                    b_idx.insert(b_idx.end(), n_idx.begin(), n_idx.end());
-                    b_idx.insert(b_idx.end(), k_idx.begin(), k_idx.end());
-
-                    auto a_val = a_full_dims(a_idx);
-                    auto b_val = b_full_dims(b_idx);
-
-                    sum += static_cast<AccDataType>(a_val) * static_cast<AccDataType>(b_val);
-                }
-
-                e_idx.clear();
-                e_idx.insert(e_idx.end(), g_idx.begin(), g_idx.end());
-                e_idx.insert(e_idx.end(), m_idx.begin(), m_idx.end());
-                e_idx.insert(e_idx.end(), n_idx.begin(), n_idx.end());
-
-                EDataType result = static_cast<EDataType>(sum);
-                if(ds_full_dims_host.size() == 0)
-                {
-                    ;
-                }
-                else if(ds_full_dims_host.size() == 1)
-                {
-                    cde_elementwise(result,
-                                    ck_tile::type_convert<AccDataType>(sum),
-                                    ck_tile::type_convert<AccDataType>(ds_full_dims_host[0](e_idx)));
-                }
-                else if(ds_full_dims_host.size() == 2)
-                {
-                    cde_elementwise(result,
-                                    ck_tile::type_convert<AccDataType>(sum),
-                                    ck_tile::type_convert<AccDataType>(ds_full_dims_host[0](e_idx)),
-                                    ck_tile::type_convert<AccDataType>(ds_full_dims_host[1](e_idx)));
-                }
-                else if(ds_full_dims_host.size() == 3)
-                {
-                    cde_elementwise(result,
                                    ck_tile::type_convert<AccDataType>(sum),
-                                    ck_tile::type_convert<AccDataType>(ds_full_dims_host[0](e_idx)),
-                                    ck_tile::type_convert<AccDataType>(ds_full_dims_host[1](e_idx)),
-                                    ck_tile::type_convert<AccDataType>(ds_full_dims_host[2](e_idx)));
-                }
-                else if(ds_full_dims_host.size() == 4)
-                {
-                    cde_elementwise(result,
-                                    ck_tile::type_convert<AccDataType>(sum),
-                                    ck_tile::type_convert<AccDataType>(ds_full_dims_host[0](e_idx)),
-                                    ck_tile::type_convert<AccDataType>(ds_full_dims_host[1](e_idx)),
-                                    ck_tile::type_convert<AccDataType>(ds_full_dims_host[2](e_idx)),
-                                    ck_tile::type_convert<AccDataType>(ds_full_dims_host[3](e_idx)));
-                }
-                else
-                {
-                    throw std::runtime_error("Unsupported NumDTensor for reference calculation");
-                }
-
-                e_full_dims_host_ref(e_idx) = static_cast<EDataType>(result);
-            }
-        }
-    }
-}
-
 } // namespace ck_tile
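
Note on the stride-aware indexing introduced above: each compute_*_offset lambda is a mixed-radix decode. The flat index is peeled apart from the innermost sub-dimension outward with % and /, and every recovered coordinate is weighted by the stride that the tensor descriptor assigns to that axis, so the same loop handles any number of G/M/N/K sub-dimensions and any stride table. A minimal self-contained sketch of the same decode; decode_offset and the packed row-major strides below are illustrative names, not part of the patch:

#include <cstddef>
#include <iostream>
#include <vector>

// Decode `flat` against `dims` (innermost dimension last) and accumulate
// coordinate * stride — the same loop structure as compute_a_offset above.
std::size_t decode_offset(std::size_t flat,
                          const std::vector<std::size_t>& dims,
                          const std::vector<std::size_t>& strides)
{
    std::size_t offset = 0;
    for(int i = static_cast<int>(dims.size()) - 1; i >= 0; --i)
    {
        offset += (flat % dims[i]) * strides[i]; // coordinate along axis i
        flat /= dims[i];                         // move to the next-outer axis
    }
    return offset;
}

int main()
{
    // Example: M_dims = {8,16} (M_total = 128) packed row-major in front of
    // an inner K extent of 64, giving strides {16 * 64, 64}.
    const std::vector<std::size_t> m_dims{8, 16};
    const std::vector<std::size_t> m_strides{16 * 64, 64};

    // m_flat = 20 decodes to (m0, m1) = (1, 4), so the offset is
    // 1 * 1024 + 4 * 64 = 1280.
    std::cout << decode_offset(20, m_dims, m_strides) << '\n';
}

Because only the stride table changes between layouts, this one decode subsumes the removed calculate_reference_multi_dimensional path, which rebuilt full coordinate vectors for every element.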
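make_ds_host_tensors and ExtractDValues both rely on the same std::index_sequence pack-expansion idiom: descs[Is]... constructs every array element in a single aggregate initialization (which is how a non-default-constructible type like HostTensor can live in a std::array), and ds_tensors[Is].mData[offset]... expands to exactly one argument per D tensor in a single variadic call. A self-contained sketch of both uses; Descriptor, Tensor, make_tensors, and apply_op are hypothetical stand-ins for the patch's types (the patch itself uses the C++20 template-lambda form of the first helper):

#include <array>
#include <cstddef>
#include <iostream>
#include <utility>

struct Descriptor { int size; };
struct Tensor // not default-constructible, like HostTensor
{
    explicit Tensor(Descriptor d) : size(d.size) {}
    int size;
};

// Build std::array<Tensor, N> by expanding descs[Is]... — mirrors
// make_ds_host_tensors in the patch.
template <std::size_t N, std::size_t... Is>
std::array<Tensor, N> make_tensors_impl(const std::array<Descriptor, N>& descs,
                                        std::index_sequence<Is...>)
{
    return std::array<Tensor, N>{Tensor(descs[Is])...};
}

template <std::size_t N>
std::array<Tensor, N> make_tensors(const std::array<Descriptor, N>& descs)
{
    return make_tensors_impl(descs, std::make_index_sequence<N>{});
}

// Expand one value per tensor into a single variadic call — mirrors
// ExtractDValues::apply_at_offset feeding ApplyCDEElementWise::apply.
template <typename F, std::size_t N, std::size_t... Is>
void apply_op(F&& f, const std::array<Tensor, N>& tensors, std::index_sequence<Is...>)
{
    f(tensors[Is].size...); // one argument per tensor, resolved at compile time
}

int main()
{
    std::array<Descriptor, 2> descs{{{3}, {5}}};
    auto tensors = make_tensors(descs);
    apply_op([](int a, int b) { std::cout << a + b << '\n'; }, // prints 8
             tensors,
             std::make_index_sequence<2>{});
}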
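Finally, the if constexpr(sizeof...(DValues) == 0) branch in ApplyCDEElementWise is what retires the old run-time if/else chain over ds_full_dims_host.size(): with zero D tensors the accumulator is cast straight through, and every other arity forwards all expanded D values in one call, so an elementwise op that cannot accept the given count now fails at compile time instead of reaching a throw. A reduced sketch; the Scale functor is a hypothetical stand-in for a CDEElementWise op:

#include <iostream>

struct Scale // illustrative op: e = acc * d0 (* d1 ...)
{
    template <typename E, typename... Ds>
    void operator()(E& e, float acc, Ds... ds) const
    {
        e = acc;
        ((e *= ds), ...); // fold over the D values
    }
};

template <typename E, typename Op, typename... DValues>
void apply(E& result, float sum, const Op& op, DValues... d_vals)
{
    if constexpr(sizeof...(DValues) == 0)
        result = static_cast<E>(sum); // no D tensors: plain conversion
    else
        op(result, sum, d_vals...); // forward every D value to the op
}

int main()
{
    float e0 = 0.f, e1 = 0.f;
    apply(e0, 2.5f, Scale{});           // zero-D case   -> 2.5
    apply(e1, 2.5f, Scale{}, 2.f, 4.f); // two-D case    -> 20
    std::cout << e0 << ' ' << e1 << '\n';
}

Calling apply with no trailing values leaves the result as a plain cast of the accumulator, matching the NumDTensor == 0 case in the patch.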