# Tensor Contraction with Multiple A, B, and D Tensors
This example demonstrates a tensor contraction operation with multiple A, B, and D tensors. It extends the basic tensor contraction to handle multiple input tensor pairs and auxiliary tensors simultaneously, enabling complex multi-input tensor network computations to be executed in a single kernel launch.
## Mathematical Formulation
This operation performs multiple tensor contractions simultaneously and combines them with auxiliary tensors.
- Multiple Tensor Contractions: Compute contractions from multiple A and B tensor pairs using Einstein summation notation.

  $$
  \begin{aligned}
  C_{\text{temp}_0} &= \text{einsum}(\text{pattern}_0, A_0, B_0) \\
  C_{\text{temp}_1} &= \text{einsum}(\text{pattern}_1, A_1, B_1) \\
  &\vdots \\
  C_{\text{temp}_K} &= \text{einsum}(\text{pattern}_K, A_K, B_K)
  \end{aligned}
  $$

- Combination with Auxiliary Tensors: Apply a user-defined function that combines all contraction results with multiple D tensors.

  $$
  E = f(C_{\text{temp}_0}, C_{\text{temp}_1}, \ldots, C_{\text{temp}_K}, D_0, D_1, \ldots, D_M)
  $$
Each contraction can have a different Einstein summation pattern, allowing for complex tensor network computations. The key optimization is that the intermediate tensors $C_{\text{temp}_i}$ are never written to global memory.
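To make the formulation concrete, here is a minimal host-side reference sketch for the simplest instance: one contraction with pattern "mk,kn->mn", one auxiliary tensor, and $f(c, d) = c + d$. The function name, row-major layout, and sizes are illustrative assumptions, not part of the library.

```cpp
#include <cstddef>
#include <vector>

// Reference evaluation of E = f(einsum("mk,kn->mn", A0, B0), D0) with
// f(c, d) = c + d. Row-major storage; purely illustrative.
void reference_contraction_multi_abd(const std::vector<float>& A0, // M x K
                                     const std::vector<float>& B0, // K x N
                                     const std::vector<float>& D0, // M x N
                                     std::vector<float>& E,        // M x N
                                     std::size_t M, std::size_t N, std::size_t K)
{
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            // Intermediate contraction result: in the fused kernel this value
            // stays in registers and is never written to global memory.
            float c_temp = 0.f;
            for(std::size_t k = 0; k < K; ++k)
                c_temp += A0[m * K + k] * B0[k * N + n];

            // Fusion function f applied in the epilogue.
            E[m * N + n] = c_temp + D0[m * N + n];
        }
}
```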
## Algorithmic Strategy: Multi-Input Contraction with Tensor-to-GEMM Mapping
This kernel extends the tensor contraction algorithm to handle multiple simultaneous contractions.
- Unified Tensor-to-GEMM Mapping: Each tensor contraction is mapped to a GEMM operation through tensor reshaping:
  - Multiple Reshaping Operations: For each contraction pair (A_i, B_i), the tensors are logically reshaped into 2D matrices based on their Einstein summation pattern.
  - Coordinated Memory Layout: The reshaping operations are coordinated to enable efficient memory access patterns across all contractions.
- Multi-Contraction Tile Computation: Within each thread block:
  - Parallel GEMM Execution: Multiple GEMM operations (representing the contractions) are computed simultaneously.
  - Complex Address Calculation: Each contraction requires its own address calculation logic for the tensor descriptor interpretation.
  - Register Management: Multiple accumulator arrays are maintained for the different contraction results.
- Tensor Fusion Epilogue: After computing all contractions:
  - Multi-Tensor Reshape: The GEMM results are logically reshaped back to their target tensor shapes.
  - Load Auxiliary Tensors: Read the corresponding elements from all D tensors.
  - Apply Fusion Function: Execute the user-defined function f, combining all results (a functor sketch follows this list).
  - Store Final Tensor: Write the combined result to the output tensor.
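As a sketch of what the user-defined function f can look like, the functor below follows the element-wise functor style common in Composable Kernel epilogues. The name AlphaBetaAdd, the float-only signature, and the fixed two-input arity (one contraction result, one D tensor) are assumptions for illustration, not the exact interface used by the example.

```cpp
#include <hip/hip_runtime.h>

// Illustrative fusion functor: e = alpha * c + beta * d, combining one
// contraction result element c with one auxiliary tensor element d.
// A sketch in the style of CK element-wise operators; not the actual
// operator type wired into the DeviceContractionMultiABD instantiation.
struct AlphaBetaAdd
{
    float alpha_;
    float beta_;

    __host__ __device__ void operator()(float& e, const float& c, const float& d) const
    {
        e = alpha_ * c + beta_ * d;
    }
};
```

Because the functor is applied element-wise while the contraction results are still in registers, the fusion adds no extra global-memory round trip.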
## Source Code Organization
- `contraction_multi_ABD_xdl.cpp`: The main example file. It sets up multiple pairs of tensors for contraction, defines the Einstein summation patterns, sets up the auxiliary D tensors, and instantiates the `DeviceContractionMultiABD` operation.
- `../../include/ck/tensor_operation/gpu/device/device_contraction_multi_abd.hpp`: The device interface for this multi-contraction fusion pattern.
- The underlying kernel manages multiple simultaneous tensor contractions with complex tensor descriptor logic and memory access patterns.
## Build and Run
### Prerequisites
Ensure the Composable Kernel library is built and installed.
```bash
cd /path/to/composable_kernel/build
make -j install
```
### Build the Example
```bash
cd /path/to/composable_kernel/example/61_contraction_multi_ABD
mkdir build && cd build
cmake \
  -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
  -DCMAKE_PREFIX_PATH="/opt/rocm;${CK_INSTALL_PATH}" \
  ..
make -j
```
### Run the Example
```bash
# Run the example with default settings
./contraction_multi_ABD_xdl

# Run with verification, data initialization, and timing
# (arguments: do_verification, init_method, time_kernel)
./contraction_multi_ABD_xdl 1 2 1
```
## Applications
This kernel is valuable for complex tensor network computations found in advanced scientific and machine learning applications.
- Tensor Network Methods: Computing multiple tensor contractions simultaneously in quantum physics simulations, such as DMRG (Density Matrix Renormalization Group) or PEPS (Projected Entangled Pair States).
- Multi-Modal Tensor Analysis: Processing multiple tensor contractions for different data modalities in machine learning applications.
- Higher-Order Statistics: Computing multiple statistical tensor operations simultaneously, such as different moments or correlation patterns.
- Advanced Neural Network Layers: Implementing complex layers that require multiple tensor operations, such as tensor decomposition layers or high-dimensional convolutions.
- Scientific Computing: Simulating physical systems that require multiple tensor contractions, such as in quantum chemistry or condensed matter physics.
## Computational Complexity
The complexity depends on the specific contraction patterns used:
- Multiple Contractions: Each contraction has its own complexity based on tensor dimensions and contraction indices
- Memory Access: Complex patterns due to multiple tensor descriptors and reshaping operations
- Register Pressure: High due to multiple accumulator arrays and intermediate results
- Instruction Diversity: Different contractions may have different computational patterns
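As a first-order estimate (counting multiply-accumulate arithmetic only, and treating each contraction as its mapped GEMM), the total cost is the sum of the individual GEMM costs:

$$
\text{FLOPs} \approx \sum_{i=0}^{K} 2\, M_i N_i K_i
$$

where, for contraction pair $i$, $M_i$ and $N_i$ are the products of the free (output) dimension extents contributed by $A_i$ and $B_i$, and $K_i$ is the product of the contracted dimension extents. The element-wise fusion adds only $O(|E|)$ work per D tensor on top of this.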
## Comparison with Single Contraction
| Aspect | Single Contraction | Multi-Contraction |
|---|---|---|
| Input Complexity | Single tensor pair | Multiple tensor pairs |
| Memory Layout | Single reshaping pattern | Multiple coordinated patterns |
| Computation | Single GEMM operation | Multiple parallel GEMMs |
| Fusion Opportunity | Simple epilogue | Complex multi-input epilogue |
| Applications | Basic tensor operations | Complex tensor networks |
This kernel shows that highly complex tensor network computations can be executed efficiently in a single fused launch, making it valuable for advanced scientific computing and machine learning research applications.