From 4027a9257946857813f678dc985592183ab9335a Mon Sep 17 00:00:00 2001
From: Mohsen Saffari
Date: Fri, 17 Oct 2025 08:53:03 +0000
Subject: [PATCH] Add stride-aware reference for batched contraction with
 independent D tensor layouts

---
 .../reference_batched_contraction.hpp         | 65 ++++++++++++++++---
 1 file changed, 57 insertions(+), 8 deletions(-)

diff --git a/include/ck_tile/host/reference/reference_batched_contraction.hpp b/include/ck_tile/host/reference/reference_batched_contraction.hpp
index 324ff06ef7..1ccff293eb 100644
--- a/include/ck_tile/host/reference/reference_batched_contraction.hpp
+++ b/include/ck_tile/host/reference/reference_batched_contraction.hpp
@@ -44,14 +44,14 @@ struct ExtractDValues<std::index_sequence<Is...>>
 {
     template <typename EDataType, typename AccDataType, typename CDEElementWise, typename DDataType, std::size_t NumDTensor>
     CK_TILE_HOST static void
-    apply_at_offset(EDataType& result,
-                    AccDataType sum,
-                    const CDEElementWise& cde_elementwise,
-                    const std::array<HostTensor<DDataType>, NumDTensor>& ds_tensors,
-                    std::size_t offset)
+    apply_at_offsets(EDataType& result,
+                     AccDataType sum,
+                     const CDEElementWise& cde_elementwise,
+                     const std::array<HostTensor<DDataType>, NumDTensor>& ds_tensors,
+                     const std::array<std::size_t, NumDTensor>& d_offsets)
     {
         ApplyCDEElementWise::apply(
-            result, sum, cde_elementwise, ds_tensors[Is].mData[offset]...);
+            result, sum, cde_elementwise, ds_tensors[Is].mData[d_offsets[Is]]...);
     }
 };
 
@@ -86,6 +86,13 @@ void compute_reference_batched_contraction(
     const auto b_strides = b_full_dims.get_strides();
     const auto e_strides = e_full_dims_host_ref.get_strides();
 
+    // Extract D tensor strides
+    std::array<std::vector<std::size_t>, NumDTensor> ds_strides;
+    for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
+    {
+        ds_strides[d] = ds_full_dims_host[d].get_strides();
+    }
+
     const ck_tile::index_t num_g_dims = G_dims.size();
     const ck_tile::index_t num_m_dims = M_dims.size();
     const ck_tile::index_t num_n_dims = N_dims.size();
@@ -188,6 +195,41 @@ void compute_reference_batched_contraction(
         return offset;
     };
 
+    // Helper to compute D tensor offset (D tensors have same shape as E: [G, M, N])
+    auto compute_d_offset = [&](ck_tile::index_t g_flat,
+                                ck_tile::index_t m_flat,
+                                ck_tile::index_t n_flat,
+                                ck_tile::index_t d_idx) -> std::size_t {
+        std::size_t offset = 0;
+        const auto& d_strides = ds_strides[d_idx];
+
+        // Decode G dimensions
+        ck_tile::index_t temp = g_flat;
+        for(ck_tile::index_t i = num_g_dims - 1; i >= 0; --i)
+        {
+            offset += (temp % G_dims[i]) * d_strides[i];
+            temp /= G_dims[i];
+        }
+
+        // Decode M dimensions
+        temp = m_flat;
+        for(ck_tile::index_t i = num_m_dims - 1; i >= 0; --i)
+        {
+            offset += (temp % M_dims[i]) * d_strides[num_g_dims + i];
+            temp /= M_dims[i];
+        }
+
+        // Decode N dimensions
+        temp = n_flat;
+        for(ck_tile::index_t i = num_n_dims - 1; i >= 0; --i)
+        {
+            offset += (temp % N_dims[i]) * d_strides[num_g_dims + num_m_dims + i];
+            temp /= N_dims[i];
+        }
+
+        return offset;
+    };
+
     // Parallel computation over G and M dimensions
     auto f_gm = [&](auto g_flat, auto m_flat) {
         for(ck_tile::index_t n_flat = 0; n_flat < N_total; ++n_flat)
@@ -208,10 +250,17 @@ void compute_reference_batched_contraction(
             // Compute output offset using strides
             const std::size_t e_offset = compute_e_offset(g_flat, m_flat, n_flat);
 
+            // Compute individual D tensor offsets using their respective strides
+            std::array<std::size_t, NumDTensor> d_offsets;
+            for(ck_tile::index_t d = 0; d < NumDTensor; ++d)
+            {
+                d_offsets[d] = compute_d_offset(g_flat, m_flat, n_flat, d);
+            }
+
             // Apply elementwise operation with D tensors using compile-time dispatch
             EDataType result = static_cast<EDataType>(sum);
-            ExtractDValues<std::make_index_sequence<NumDTensor>>::apply_at_offset(
-                result, sum, cde_elementwise, ds_full_dims_host, e_offset);
+            ExtractDValues<std::make_index_sequence<NumDTensor>>::apply_at_offsets(
+                result, sum, cde_elementwise, ds_full_dims_host, d_offsets);
 
             // Store result using stride-aware indexing
             e_full_dims_host_ref.mData[e_offset] = static_cast<EDataType>(result);
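
Note for reviewers: the previous reference reused the E offset for every D
tensor, which is only valid when all D tensors share E's layout. This patch
instead folds the same flat (g_flat, m_flat, n_flat) coordinates through each
D tensor's own strides. A minimal standalone sketch of that decode-and-fold
step follows; it is illustrative only, and the names (decode_offset) and the
example dims/strides are made up for this note rather than taken from ck_tile:

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Fold a flat index over `dims` through `strides` (innermost dimension
    // last), mirroring the per-group decode loops in compute_d_offset above.
    std::size_t decode_offset(std::size_t flat,
                              const std::vector<std::size_t>& dims,
                              const std::vector<std::size_t>& strides)
    {
        std::size_t offset = 0;
        for(std::ptrdiff_t i = static_cast<std::ptrdiff_t>(dims.size()) - 1; i >= 0; --i)
        {
            offset += (flat % dims[i]) * strides[i];
            flat /= dims[i];
        }
        return offset;
    }

    int main()
    {
        const std::vector<std::size_t> dims{2, 3};   // logical 2x3 view, 6 elements
        const std::vector<std::size_t> packed{3, 1}; // contiguous row-major layout
        const std::vector<std::size_t> padded{4, 1}; // row stride padded to 4

        for(std::size_t flat = 0; flat < 6; ++flat)
        {
            std::cout << "flat " << flat
                      << " -> packed " << decode_offset(flat, dims, packed)
                      << ", padded " << decode_offset(flat, dims, padded) << '\n';
        }
    }

Running this shows flat index 5 landing at offset 5 under the packed layout
but offset 6 under the padded one: exactly the divergence that the per-tensor
d_offsets array now handles and that a shared e_offset could not.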