# Tensor Contraction with Multiple A, B, and D Tensors
This example demonstrates a **tensor contraction operation with multiple A, B, and D tensors**. It extends the basic tensor contraction to handle multiple input tensor pairs and auxiliary tensors simultaneously, enabling complex multi-input tensor network computations to be executed in a single kernel launch.

## Mathematical Formulation

This operation performs multiple tensor contractions simultaneously and combines the results with auxiliary tensors.

1. **Multiple Tensor Contractions**: Compute contractions from multiple A and B tensor pairs using Einstein summation notation:

   $C_{temp0} = \text{einsum}(\text{pattern}_0, A_0, B_0)$

   $C_{temp1} = \text{einsum}(\text{pattern}_1, A_1, B_1)$

   $\vdots$

   $C_{tempK} = \text{einsum}(\text{pattern}_K, A_K, B_K)$

2. **Combination with Auxiliary Tensors**: Apply a user-defined function `f` that combines all contraction results with multiple D tensors:

   $E = f(C_{temp0}, C_{temp1}, \ldots, C_{tempK}, D_0, D_1, \ldots, D_M)$

Each contraction can have its own Einstein summation pattern, allowing for complex tensor network computations. The key optimization is that the intermediate tensors $C_{temp i}$ are **never written to global memory**.
## Algorithmic Strategy: Multi-Input Contraction with Tensor-to-GEMM Mapping

This kernel extends the tensor contraction algorithm to handle multiple simultaneous contractions.

1. **Unified Tensor-to-GEMM Mapping**: Each tensor contraction is mapped to a GEMM operation through tensor reshaping:

   - **Multiple Reshaping Operations**: For each contraction pair `(A_i, B_i)`, the tensors are logically reshaped into 2D matrices based on their Einstein summation pattern.

   - **Coordinated Memory Layout**: The reshaping operations are coordinated to enable efficient memory access patterns across all contractions.
2. **Multi-Contraction Tile Computation**: Within each thread block:

   - **Parallel GEMM Execution**: Multiple GEMM operations (representing the contractions) are computed simultaneously.

   - **Complex Address Calculation**: Each contraction requires its own address calculation logic for its tensor descriptor interpretation.

   - **Register Management**: Separate accumulator arrays are maintained for the different contraction results.

3. **Tensor Fusion Epilogue**: After computing all contractions:

   - **Multi-Tensor Reshape**: The GEMM results are logically reshaped back to their target tensor shapes.

   - **Load Auxiliary Tensors**: Read the corresponding elements from all D tensors.

   - **Apply Fusion Function**: Execute the user-defined function `f`, combining all results.

   - **Store Final Tensor**: Write the combined result to the output tensor.
## Source Code Organization

- [`contraction_multi_ABD_xdl.cpp`](./contraction_multi_ABD_xdl.cpp): The main example file. It sets up multiple pairs of tensors for contraction, defines the Einstein summation patterns, sets up the auxiliary D tensors, and instantiates the `DeviceContractionMultiABD` operation.
- [`../../include/ck/tensor_operation/gpu/device/device_contraction_multi_abd.hpp`](../../include/ck/tensor_operation/gpu/device/device_contraction_multi_abd.hpp): The device interface for this multi-contraction fusion pattern.
- The underlying kernel manages multiple simultaneous tensor contractions with complex tensor descriptor logic and memory access patterns.
## Build and Run

### Prerequisites

Ensure the Composable Kernel library is built and installed.

```bash
cd /path/to/composable_kernel/build
make -j install
```

### Build the Example

```bash
cd /path/to/composable_kernel/example/61_contraction_multi_ABD
mkdir build && cd build

cmake \
    -DCMAKE_CXX_COMPILER=/opt/rocm/bin/hipcc \
    -DCMAKE_PREFIX_PATH="/opt/rocm;${CK_INSTALL_PATH}" \
    ..

make -j
```

### Run the Example

```bash
# Run the example with default settings
./contraction_multi_ABD_xdl

# Run with verification, data initialization, and timing
./contraction_multi_ABD_xdl 1 2 1
```
## Applications

This kernel is valuable for complex tensor network computations found in advanced scientific and machine learning applications.

- **Tensor Network Methods**: Computing multiple tensor contractions simultaneously in quantum physics simulations, such as DMRG (Density Matrix Renormalization Group) or PEPS (Projected Entangled Pair States).
- **Multi-Modal Tensor Analysis**: Processing multiple tensor contractions for different data modalities in machine learning applications.
- **Higher-Order Statistics**: Computing multiple statistical tensor operations simultaneously, such as different moments or correlation patterns.
- **Advanced Neural Network Layers**: Implementing complex layers that require multiple tensor operations, such as tensor decomposition layers or high-dimensional convolutions.
- **Scientific Computing**: Simulating physical systems that require multiple tensor contractions, as in quantum chemistry or condensed matter physics.
## Computational Complexity

The complexity depends on the specific contraction patterns used:

- **Multiple Contractions**: Each contraction has its own cost, determined by its tensor dimensions and contraction indices.
- **Memory Access**: Access patterns are complex due to multiple tensor descriptors and reshaping operations.
- **Register Pressure**: High, due to multiple accumulator arrays and intermediate results.
- **Instruction Diversity**: Different contractions may have different computational patterns.
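As an illustrative example (the index grouping here is assumed, not taken from the example's actual problem sizes), a single contraction with free index groups $(M_0, M_1)$ and $(N_0, N_1)$ and contracted indices $(K_0, K_1)$ performs

$$\text{FLOPs} = 2 \cdot M_0 M_1 \cdot N_0 N_1 \cdot K_0 K_1,$$

so the multi-contraction kernel's arithmetic cost is the sum of this quantity over all contraction pairs, while the fusion epilogue adds only work proportional to the output size.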
## Comparison with Single Contraction

| Aspect | Single Contraction | Multi-Contraction |
|--------|-------------------|-------------------|
| **Input Complexity** | Single tensor pair | Multiple tensor pairs |
| **Memory Layout** | Single reshaping pattern | Multiple coordinated patterns |
| **Computation** | Single GEMM operation | Multiple parallel GEMMs |
| **Fusion Opportunity** | Simple epilogue | Complex multi-input epilogue |
| **Applications** | Basic tensor operations | Complex tensor networks |

This kernel showcases the ability to handle extremely complex tensor network computations efficiently, making it valuable for advanced scientific computing and machine learning research applications.