.. meta::
:description: CK Tile LDS index swapping documentation
:keywords: CK Tile, LDS, index swapping, XOR preshuffle, bank conflicts, GPU optimization
.. _ck_tile_lds_index_swapping:
********************************
Load Data Share Index Swapping
********************************
Overview
========
Local Data Share (LDS) index swapping, also known as XOR preshuffle, is a critical optimization technique in CK Tile for resolving bank conflicts in shared memory. Bank conflicts occur when multiple threads in a warp attempt to access different addresses within the same memory bank simultaneously, causing serialization and performance degradation. CK Tile generalizes the XOR preshuffle technique through a compile-time coordinate transformation system that automatically handles complex access patterns.
The key insight is that transforming the logical 2D coordinates used to access LDS into a different 2D coordinate space ensures that threads accessing data simultaneously access different memory banks. This transformation is implemented through CK Tile's composable transform system, making it both flexible and efficient. See :ref:`ck_tile_transforms` and :ref:`ck_tile_coordinate_systems` for more information about the composable transform system.
Coordinate Transformation Pipeline
==================================
CK Tile performs coordinate transformations to bring LDS access from the original 2D position (M, K dimensions) into transformed (M', K') coordinates:
Step 1: XOR Transform
---------------------
The original K coordinate is split into K0 and K1, where K1 represents the thread vector size along the K dimension (KPack) and K0 is KPerBlock/KPack.
..
Original mermaid diagram (edit here, then run update_diagrams.py)
.. mermaid::
graph TB
subgraph "3D LDS coordinate [K0, M, K1]"
K0["KPerBlock/KPack * MLdsLayer
K0"]
M["MPerBlock/MLdsLayer
M"]
K1["KPack
K1"]
end
subgraph "XOR Transform"
XT["make_xor_transform"]
end
subgraph "Update K0 with XOR transformation"
K01["KPerBlock/KPack * MLdsLayer
K0'"]
M1["MPerBlock/MLdsLayer
M"]
K11["KPack
K1"]
end
K0 --> XT
M --> XT
K1 --> K11
XT --> K01
XT --> M1
style K0 fill:#e3f2fd,stroke:#1976d2,stroke-width:2px
style K01 fill:#e3f2fd,stroke:#1976d2,stroke-width:2px
style M fill:#e3f2fd,stroke:#1976d2,stroke-width:2px
style M1 fill:#e3f2fd,stroke:#1976d2,stroke-width:2px
style K1 fill:#fff3e0,stroke:#f57c00,stroke-width:2px
style K11 fill:#fff3e0,stroke:#f57c00,stroke-width:2px
.. image:: diagrams/lds_index_swapping_1.svg
:alt: Diagram
:align: center
The XOR transformation updates the K0 coordinate using the formula:
.. math::
K0' = K0^{(M \% (KPerBlock / KPack * MLdsLayer))}
This XOR operation redistributes accesses across memory banks by mixing bits from the M and K dimensions.
Step 2: Unmerge Transform
-------------------------
The transformed K0' is split into L and K0'' components, creating an intermediate 4D coordinate space. This is necessary when MLdsLayer > 1, allowing multiple rows to share the same set of memory banks for better utilization with smaller tile sizes.
..
Original mermaid diagram (edit here, then run update_diagrams.py)
.. mermaid::
graph TB
subgraph "3D LDS coordinate [K0', M, K1]"
K0["KPerBlock/KPack * MLdsLayer
K0'"]
M["MPerBlock/MLdsLayer
M"]
K1["KPack
K1"]
end
subgraph "Unmerge into 2 components"
UM["make_unmerge_transform"]
end
subgraph "4D intermediate transformation space"
L["MLdsLayer
L"]
M1["MPerBlock/MLdsLayer
M"]
K01["KPerBlock/KPack
K0''"]
K11["KPack
K1"]
end
K0 --> UM
M --> M1
K1 --> K11
UM --> L
UM --> K01
style K0 fill:#e3f2fd,stroke:#1976d2,stroke-width:2px
style L fill:#e3f2fd,stroke:#1976d2,stroke-width:2px
style K01 fill:#e3f2fd,stroke:#1976d2,stroke-width:2px
style M fill:#e8f5e9,stroke:#388e3c,stroke-width:2px
style M1 fill:#e8f5e9,stroke:#388e3c,stroke-width:2px
style K1 fill:#fff3e0,stroke:#f57c00,stroke-width:2px
style K11 fill:#fff3e0,stroke:#f57c00,stroke-width:2px
.. image:: diagrams/lds_index_swapping_2.svg
:alt: Diagram
:align: center
The unmerge operation:
.. math::
L = K0' / (KPerBlock/KPack)
K0'' = K0' \% (KPerBlock/KPack)
When MLdsLayer == 1, this simplifies to L=0 and K0''=K0'.
Step 3: Merge Transform
-----------------------
The final step merges the 4D coordinates back into 2D transformed coordinates (M', K').
..
Original mermaid diagram (edit here, then run update_diagrams.py)
.. mermaid::
graph TB
subgraph "4D LDS Coordinates [L, M, K0'', K1]"
L["MLdsLayer
L"]
M1["MPerBlock/MLdsLayer
M"]
K0["KPerBlock/KPack
K0''"]
K1["KPack
K1"]
end
subgraph "Merge into 1 component"
ME0["make_merge_transform"]
end
subgraph "Merge into 1 component"
ME1["make_merge_transform"]
end
subgraph "Transformed 2D coordinates [M', K']"
M11["MPerBlock
M'"]
K01["KPerBlock
K'"]
end
L --> ME0
M1 --> ME0
K0 --> ME1
K1 --> ME1
ME0 --> M11
ME1 --> K01
style K0 fill:#e3f2fd,stroke:#1976d2,stroke-width:2px
style K1 fill:#e3f2fd,stroke:#1976d2,stroke-width:2px
style K01 fill:#e3f2fd,stroke:#1976d2,stroke-width:2px
style M1 fill:#e8f5e9,stroke:#388e3c,stroke-width:2px
style L fill:#e8f5e9,stroke:#388e3c,stroke-width:2px
style M11 fill:#e8f5e9,stroke:#388e3c,stroke-width:2px
.. image:: diagrams/lds_index_swapping_3.svg
:alt: Diagram
:align: center
C++ Implementation
==================
Here's how the complete transformation chain is implemented in CK Tile using :ref:`ck_tile_descriptors` and transforms:
.. code-block:: cpp
template
struct LdsIndexSwapping {
static constexpr index_t KPerBlock_over_KPack = KPerBlock / KPack;
static constexpr index_t MPerBlock_over_MLdsLayer = MPerBlock / MLdsLayer;
// Step 1: Create base descriptor
using BaseLengths = Sequence<
KPerBlock_over_KPack * MLdsLayer,
MPerBlock_over_MLdsLayer,
KPack
>;
using BaseStrides = Sequence<
KPack,
KPerBlock * MLdsLayer,
1
>;
using BaseDescriptor = TensorDescriptor;
// Step 2: Apply XOR transform
using PermutedDescriptor = decltype(
transform_tensor_descriptor(
BaseDescriptor{},
make_tuple(
make_xor_transform(
Sequence{}
),
make_pass_through_transform(Number{})
),
Sequence<1, 0>{}, // XOR on dims [1,0]
Sequence<2>{} // Pass through dim 2
)
);
// Step 3: Apply unmerge and final transforms
using FinalDescriptor = decltype(
transform_tensor_descriptor(
PermutedDescriptor{},
make_tuple(
make_unmerge_transform(
Sequence{}
),
make_pass_through_transform(Number{}),
make_pass_through_transform(Number{})
),
Sequence<0>{}, // Unmerge dim 0
Sequence<1>{}, // Pass through dim 1
Sequence<2>{}, // Pass through dim 2
Sequence<0, 2>{}, // Output dims from unmerge
Sequence<1>{}, // Output dim 1
Sequence<3>{} // Output dim 3
)
);
};
Practical Usage in GEMM
==========================
Here's how LDS index swapping is used in a real GEMM kernel. See :ref:`ck_tile_gemm_optimization` for more information about GEMM optimization.
.. code-block:: cpp
template
__global__ void gemm_kernel_with_lds_swapping(
const DataType* __restrict__ a_global,
const DataType* __restrict__ b_global,
DataType* __restrict__ c_global,
index_t M, index_t N, index_t K)
{
// Shared memory allocation
__shared__ DataType a_lds[BlockM * BlockK];
__shared__ DataType b_lds[BlockK * BlockN];
// Create LDS descriptor with index swapping
constexpr index_t MLdsLayer = 2; // Typical value for bank conflict avoidance
using ALdsDesc = typename LdsIndexSwapping<
BlockK, KPack, MLdsLayer, BlockM
>::FinalDescriptor;
// Load from global to LDS with swapped indices
auto load_a_to_lds = [&](index_t k_offset) {
// Each thread loads its portion
index_t tid = threadIdx.x;
constexpr index_t NumThreads = blockDim.x;
constexpr index_t ElementsPerThread = (BlockM * BlockK) / NumThreads;
#pragma unroll
for (index_t i = 0; i < ElementsPerThread; ++i) {
index_t linear_idx = tid * ElementsPerThread + i;
// Convert linear index to 2D coordinates
index_t m_idx = linear_idx / BlockK;
index_t k_idx = linear_idx % BlockK;
// Load from global memory
DataType value = a_global[
(blockIdx.y * BlockM + m_idx) * K + k_offset + k_idx
];
// Store to LDS using swapped coordinates
ALdsDesc desc;
index_t lds_offset = desc.calculate_offset({
0, // L component (for this example)
m_idx / MLdsLayer, // M component
k_idx / KPack, // K0 component
k_idx % KPack // K1 component
});
a_lds[lds_offset] = value;
}
};
// Main GEMM computation loop
for (index_t k = 0; k < K; k += BlockK) {
// Load tiles to LDS with index swapping
load_a_to_lds(k);
__syncthreads();
// Compute using swapped LDS layout
// ... (matrix multiplication using transformed coordinates)
}
}
Bank Conflict Analysis
======================
The effectiveness of index swapping can be analyzed by examining access patterns:
.. code-block:: cpp
template
struct BankConflictAnalyzer {
static constexpr index_t NumBanks = 32;
static constexpr index_t BankWidth = 4; // 4 bytes per bank
template
static void analyze_access_pattern() {
// Simulate warp access pattern
index_t bank_access[NumBanks] = {0};
// Each thread in warp accesses one element
for (index_t tid = 0; tid < WarpSize; ++tid) {
// Calculate coordinates for this thread
index_t m_coord = tid / 8; // Example mapping
index_t k_coord = tid % 8;
// Get LDS offset using descriptor
LdsDescriptor desc;
index_t offset = desc.calculate_offset({m_coord, k_coord});
// Determine bank
index_t bank = (offset * sizeof(float) / BankWidth) % NumBanks;
bank_access[bank]++;
}
// Check for conflicts
index_t max_conflict = 0;
for (index_t bank = 0; bank < NumBanks; ++bank) {
max_conflict = max(max_conflict, bank_access[bank]);
}
printf("Max bank conflict: %d-way\n", max_conflict);
}
};
Performance Benefits
====================
LDS index swapping provides several key benefits:
1. **Conflict-Free Access**: Eliminates or significantly reduces bank conflicts
2. **Higher Throughput**: Enables full memory bandwidth utilization
3. **Automatic Optimization**: Transformation parameters can be tuned per architecture
4. **Composability**: Integrates seamlessly with other CK Tile transformations
Advanced Configurations
=======================
Different configurations can be used based on tile sizes and data types:
.. code-block:: cpp
// Configuration for different scenarios
template
struct LdsSwappingConfig {
// Smaller tiles may need different MLdsLayer
static constexpr index_t MLdsLayer =
(TileSize <= 32) ? 1 :
(TileSize <= 64) ? 2 : 4;
// Adjust KPack based on data type
static constexpr index_t KPack =
sizeof(DataType) == 2 ? 8 : // FP16/BF16
sizeof(DataType) == 4 ? 4 : 2; // FP32
// Validate configuration
static_assert(TileSize % (MLdsLayer * KPack) == 0,
"Tile size must be divisible by MLdsLayer * KPack");
};
Integration with Tile Distribution
==================================
LDS index swapping works seamlessly with CK Tile's distribution system. See :ref:`ck_tile_tile_distribution` for more information about CK Tile's distribution system.
.. code-block:: cpp
template
struct DistributedLdsAccess {
using LdsDesc = typename LdsIndexSwapping<...>::FinalDescriptor;
__device__ void load_from_lds(
const float* lds_ptr,
StaticDistributedTensor& reg_tensor)
{
// Each thread loads its distributed portion
auto coord = make_tensor_coordinate(LdsDesc{}, {0, 0, 0, 0});
#pragma unroll
for (index_t i = 0; i < reg_tensor.size(); ++i) {
// Calculate swapped LDS coordinates for this element
auto [m, k] = TileDistribution::get_local_tile_index(i);
// Move to correct position
move_tensor_coordinate(LdsDesc{}, coord, {0, m, k/4, k%4});
// Load with transformed coordinates
reg_tensor[i] = lds_ptr[coord.get_offset()];
}
}
};
Summary
=======
LDS index swapping in CK Tile provides a effective and flexible solution to the bank conflict problem:
- **Generalized XOR Preshuffle**: Extends the basic XOR technique through composable transforms
- **Multi-Step Pipeline**: Coordinates flow through XOR → Unmerge → Merge transformations
- **Automatic Optimization**: Parameters like MLdsLayer adapt to tile sizes and data types
- **Zero Overhead**: All transformations resolve at compile time
- **Seamless Integration**: Works naturally with other CK Tile components
By understanding and utilizing LDS index swapping, kernels can achieve maximum shared memory bandwidth, which is often the limiting factor in GPU kernel performance. The transformation-based approach makes it easy to experiment with different swapping strategies while maintaining code clarity.
For practical examples of how index swapping is used in complete kernels, see :ref:`ck_tile_swizzling_example`. For more on coordinate operations used here, see :ref:`ck_tile_coordinate_movement` and :ref:`ck_tile_tensor_coordinates`.