# Tensor Contraction with Multiple A, B, and D Tensors
This example demonstrates a **tensor contraction operation with multiple A, B, and D tensors**. It extends the basic tensor contraction to handle multiple input tensor pairs and auxiliary tensors simultaneously, enabling complex multi-input tensor network computations to be executed in a single kernel launch.
## Mathematical Formulation
This operation performs multiple tensor contractions simultaneously and combines them with auxiliary tensors.
1. **Multiple Tensor Contractions**: Compute contractions from multiple A and B tensor pairs using Einstein summation notation.
$C_{\text{temp},0} = \text{einsum}(\text{pattern}_0, A_0, B_0)$
$C_{\text{temp},1} = \text{einsum}(\text{pattern}_1, A_1, B_1)$
$\vdots$
$C_{\text{temp},K} = \text{einsum}(\text{pattern}_K, A_K, B_K)$
2. **Combination with Auxiliary Tensors**: Apply a user-defined function that combines all contraction results with multiple D tensors.
$E = f(C_{\text{temp},0}, C_{\text{temp},1}, \ldots, C_{\text{temp},K}, D_0, D_1, \ldots, D_M)$
Each contraction can use a different Einstein summation pattern, allowing for complex tensor network computations. The key optimization is that the intermediate results $C_{\text{temp},i}$ are **never written to global memory**.
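As a point of reference, the semantics above can be written as a naive host-side loop nest in plain C++, independent of the CK device API. The sizes, the number of contractions, and the fusion function `f` below are illustrative choices only, and each contraction is shown already flattened to a matrix product (see the mapping in the next section):
```cpp
#include <cstddef>
#include <vector>

// Naive host-side reference for the formulation above, with each contraction
// already flattened to a matrix product. The fusion function f and the
// problem sizes are illustrative only, not those used by this example.
int main()
{
    const std::size_t M = 4, N = 4, K = 8; // flattened free / contracted sizes (illustrative)

    // Two A/B pairs and two auxiliary D tensors, all packed row-major.
    std::vector<float> A0(M * K, 1.0f), B0(K * N, 2.0f);
    std::vector<float> A1(M * K, 0.5f), B1(K * N, 1.5f);
    std::vector<float> D0(M * N, 3.0f), D1(M * N, -1.0f);
    std::vector<float> E(M * N, 0.0f);

    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            // Step 1: the contractions, kept only in local variables
            // (the kernel keeps these intermediates out of global memory).
            float c_temp0 = 0.0f, c_temp1 = 0.0f;
            for(std::size_t k = 0; k < K; ++k)
            {
                c_temp0 += A0[m * K + k] * B0[k * N + n];
                c_temp1 += A1[m * K + k] * B1[k * N + n];
            }

            // Step 2: user-defined fusion f combining the contraction results
            // with the auxiliary D tensors (an arbitrary illustrative choice of f).
            E[m * N + n] = c_temp0 + c_temp1 * D0[m * N + n] + D1[m * N + n];
        }

    return 0;
}
```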
## Algorithmic Strategy: Multi-Input Contraction with Tensor-to-GEMM Mapping
This kernel extends the tensor contraction algorithm to handle multiple simultaneous contractions.
1. **Unified Tensor-to-GEMM Mapping**: Each tensor contraction is mapped to a GEMM operation through tensor reshaping (a host-side sketch of this mapping follows this list):
- **Multiple Reshaping Operations**: For each contraction pair `(A_i, B_i)`, the tensors are logically reshaped into 2D matrices based on their Einstein summation pattern.
- **Coordinated Memory Layout**: The reshaping operations are coordinated to enable efficient memory access patterns across all contractions.
2. **Multi-Contraction Tile Computation**: Within each thread block:
- **Parallel GEMM Execution**: Multiple GEMM operations (representing the contractions) are computed simultaneously.
- **Complex Address Calculation**: Each contraction requires its own address-calculation logic, derived from its own tensor descriptors.
- **Register Management**: Multiple accumulator arrays are maintained for the different contraction results.
3. **Tensor Fusion Epilogue**: After computing all contractions:
- **Multi-Tensor Reshape**: The GEMM results are logically reshaped back to their target tensor shapes.
- **Load Auxiliary Tensors**: Read corresponding elements from all D tensors.
- **Apply Fusion Function**: Execute the user-defined function `f` combining all results.
- **Store Final Tensor**: Write the combined result to the output tensor.
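To make step 1 concrete, the sketch below (plain host C++, assuming packed row-major layouts purely for illustration) flattens a contraction with two free dimensions per operand and two contracted dimensions into a single GEMM. The device kernel achieves the equivalent mapping through tensor descriptors rather than explicit reshaping or copies:
```cpp
#include <cstddef>
#include <vector>

// Flatten E[m0,m1,n0,n1] = sum_{k0,k1} A[m0,m1,k0,k1] * B[k0,k1,n0,n1]
// into one GEMM with M = M0*M1, N = N0*N1, K = K0*K1.
// Packed row-major layouts are assumed for illustration; the real kernel
// expresses this mapping through tensor descriptors instead of copies.
void contraction_as_gemm(const std::vector<float>& A, // [M0, M1, K0, K1]
                         const std::vector<float>& B, // [K0, K1, N0, N1]
                         std::vector<float>& E,       // [M0, M1, N0, N1]
                         std::size_t M0, std::size_t M1,
                         std::size_t N0, std::size_t N1,
                         std::size_t K0, std::size_t K1)
{
    const std::size_t M = M0 * M1, N = N0 * N1, K = K0 * K1;

    // With packed layouts, the flattened index m = m0*M1 + m1 walks A's free
    // dimensions exactly like a GEMM row index, and likewise for n and k.
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t n = 0; n < N; ++n)
        {
            float acc = 0.0f;
            for(std::size_t k = 0; k < K; ++k)
                acc += A[m * K + k] * B[k * N + n];
            E[m * N + n] = acc;
        }
}

int main()
{
    const std::size_t M0 = 2, M1 = 3, N0 = 2, N1 = 2, K0 = 4, K1 = 2; // illustrative sizes
    std::vector<float> A(M0 * M1 * K0 * K1, 1.0f);
    std::vector<float> B(K0 * K1 * N0 * N1, 2.0f);
    std::vector<float> E(M0 * M1 * N0 * N1, 0.0f);
    contraction_as_gemm(A, B, E, M0, M1, N0, N1, K0, K1);
    return 0;
}
```
Because the flattening is purely an index calculation for packed layouts, each contraction can be treated as a GEMM without materializing reshaped copies of A or B.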
## Source Code Organization
- [`contraction_multi_ABD_xdl.cpp`](./contraction_multi_ABD_xdl.cpp): The main example file. It sets up multiple pairs of tensors for contraction, defines the Einstein summation patterns, sets up auxiliary D tensors, and instantiates the `DeviceContractionMultiABD` operation.
- [`../../include/ck/tensor_operation/gpu/device/device_contraction_multi_abd.hpp`](../../include/ck/tensor_operation/gpu/device/device_contraction_multi_abd.hpp): The device interface for this multi-contraction fusion pattern.
- The underlying kernel manages multiple simultaneous tensor contractions with complex tensor descriptor logic and memory access patterns.
## Build and Run
### Prerequisites
Ensure the Composable Kernel library is built and installed.
```bash
cd /path/to/composable_kernel/build
make -j install
```
### Build the Example
```bash
cd /path/to/composable_kernel/example/61_contraction_multi_ABD
mkdir build && cd build
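# Set CK_INSTALL_PATH to the directory where the Composable Kernel library was installed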
cmake \
-DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
-DCMAKE_PREFIX_PATH="/opt/rocm;${CK_INSTALL_PATH}" \
..
make -j
```
### Run the Example
```bash
# Run the example with default settings
./contraction_multi_ABD_xdl
# Run with verification, data initialization, and kernel timing enabled
# (arguments: verification, initialization method, time kernel)
./contraction_multi_ABD_xdl 1 2 1
```
## Applications
This kernel is valuable for complex tensor network computations found in advanced scientific and machine learning applications.
- **Tensor Network Methods**: Computing multiple tensor contractions simultaneously in quantum physics simulations, such as DMRG (Density Matrix Renormalization Group) or PEPS (Projected Entangled Pair States).
- **Multi-Modal Tensor Analysis**: Processing multiple tensor contractions for different data modalities in machine learning applications.
- **Higher-Order Statistics**: Computing multiple statistical tensor operations simultaneously, such as different moments or correlation patterns.
- **Advanced Neural Network Layers**: Implementing complex layers that require multiple tensor operations, such as tensor decomposition layers or high-dimensional convolutions.
- **Scientific Computing**: Simulating physical systems that require multiple tensor contractions, such as in quantum chemistry or condensed matter physics.
## Computational Complexity
The complexity depends on the specific contraction patterns used:
- **Multiple Contractions**: Each contraction has its own complexity based on tensor dimensions and contraction indices
- **Memory Access**: Complex patterns due to multiple tensor descriptors and reshaping operations
- **Register Pressure**: High due to multiple accumulator arrays and intermediate results
- **Instruction Diversity**: Different contractions may have different computational patterns
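As a rough first-order estimate for the first point above, a single contraction whose free dimensions flatten to $M$ and $N$ and whose contracted dimensions flatten to $K$ performs about $2MNK$ floating-point operations, and the costs of the individual contractions add. For example (illustrative sizes only), two contractions with $M = N = K = 1024$ total roughly $2 \times 2 \cdot 1024^3 \approx 4.3$ GFLOP, plus the elementwise work of the fusion epilogue and the reads of the D tensors.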
## Comparison with Single Contraction
| Aspect | Single Contraction | Multi-Contraction |
|--------|-------------------|-------------------|
| **Input Complexity** | Single tensor pair | Multiple tensor pairs |
| **Memory Layout** | Single reshaping pattern | Multiple coordinated patterns |
| **Computation** | Single GEMM operation | Multiple parallel GEMMs |
| **Fusion Opportunity** | Simple epilogue | Complex multi-input epilogue |
| **Applications** | Basic tensor operations | Complex tensor networks |
This kernel shows how complex tensor network computations can be executed efficiently in a single fused launch, making it valuable for advanced scientific computing and machine learning research.