mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-03-23 16:47:40 +00:00
* Adding CK Tile documentation * Updates based on feedback * Fix tile window API description * Fix remaining images * add documentation about flush_cache and rotating_buffer functionality in ck_tile * Supplement the documentation * light edit of the ck tile conceptual doc * Fixes for ruff check. * Fixes for ruff check 2. * Fixes for ruff check 3. --------- Co-authored-by: Vidyasagar <vanantha@amd.com> Co-authored-by: AviralGoelAMD <aviral.goel@amd.com> Co-authored-by: ThomasNing <thomas.ning@amd.com> Co-authored-by: Vidyasagar Ananthan <vidyasagar.ananthan@amd.com>
490 lines
17 KiB
ReStructuredText
490 lines
17 KiB
ReStructuredText
.. meta::
|
|
:description: CK Tile encoding internals documentation
|
|
:keywords: CK Tile, encoding, tile distribution, GPU programming, compile-time computation
|
|
|
|
.. _ck_tile_encoding_internals:
|
|
|
|
******************
|
|
Encoding Internals
|
|
******************
|
|
|
|
Overview
|
|
========
|
|
|
|
The tile distribution encoding system represents the core mathematical framework that transforms high-level tensor distribution specifications into concrete, optimized GPU kernel implementations. This advanced compile-time machinery bridges the gap between abstract mathematical descriptions and executable coordinate transformations, enabling the Composable Kernel framework to generate highly efficient code for complex tensor operations.
|
|
|
|
At its heart, the encoding system defines how multi-dimensional tensor data is distributed across GPU processing elements through a hierarchical decomposition scheme. By specifying relationships between different coordinate spaces of replication (R), hierarchical (H), partition (P), and yield (Y) dimension, the encoding provides a complete blueprint for data layout and access patterns that can be resolved entirely at compile time. This is the internal mechanism behind :ref:`ck_tile_tile_distribution`. See :ref:`ck_tile_coordinate_systems` for more information about coordinate spaces.
|
|
|
|
..
|
|
Original mermaid diagram (edit here, then run update_diagrams.py)
|
|
|
|
.. mermaid::
|
|
|
|
graph TB
|
|
subgraph "Encoding Components"
|
|
RS["R-space Lengths<br/>Replication dimensions"]
|
|
HS["H-space Lengths<br/>Hierarchical decomposition<br/>[[2,2],[2,2]]"]
|
|
P2RH["P→RH Mappings<br/>Thread to hierarchy<br/>Major/Minor"]
|
|
Y2RH["Y→RH Mappings<br/>Element to hierarchy<br/>Major/Minor"]
|
|
end
|
|
|
|
subgraph "Generated Components"
|
|
ADAPTOR["ps_ys_to_xs_adaptor<br/>Coordinate transformer"]
|
|
DESC["ys_to_d_descriptor<br/>Memory linearizer"]
|
|
ENC["Encoding<br/>Original specification"]
|
|
end
|
|
|
|
subgraph "Transformation Chain"
|
|
T1["Replicate<br/>Transform"]
|
|
T2["Unmerge<br/>Transform"]
|
|
T3["Merge<br/>Transform"]
|
|
end
|
|
|
|
RS --> T1
|
|
HS --> T2
|
|
P2RH --> ADAPTOR
|
|
Y2RH --> ADAPTOR
|
|
|
|
T1 --> T2
|
|
T2 --> T3
|
|
T3 --> ADAPTOR
|
|
|
|
HS --> DESC
|
|
Y2RH --> DESC
|
|
|
|
style RS fill:#fce4ec,stroke:#c2185b,stroke-width:2px
|
|
style HS fill:#e8f5e9,stroke:#388e3c,stroke-width:2px
|
|
style ADAPTOR fill:#e3f2fd,stroke:#1976d2,stroke-width:3px
|
|
style DESC fill:#fff3e0,stroke:#f57c00,stroke-width:3px
|
|
|
|
|
|
|
|
.. image:: diagrams/encoding_internals_1.svg
|
|
:alt: Diagram
|
|
:align: center
|
|
|
|
Encoding Structure
|
|
==================
|
|
|
|
The tile distribution encoding employs a template-based type system that captures the complete specification of tensor distribution patterns at compile time:
|
|
|
|
.. code-block:: cpp
|
|
|
|
template <typename RsLengths_, // Replication dimension lengths
|
|
typename HsLengthss_, // Hierarchical dimension lengths
|
|
typename Ps2RHssMajor_, // P to RH mapping (major)
|
|
typename Ps2RHssMinor_, // P to RH mapping (minor)
|
|
typename Ys2RHsMajor_, // Y to RH mapping (major)
|
|
typename Ys2RHsMinor_> // Y to RH mapping (minor)
|
|
struct tile_distribution_encoding
|
|
{
|
|
// All computations resolved at compile time
|
|
static constexpr index_t NDimX = HsLengthss::size();
|
|
static constexpr index_t NDimP = Ps2RHssMajor::size();
|
|
static constexpr index_t NDimY = Ys2RHsMajor::size();
|
|
static constexpr index_t NDimR = RsLengths::size();
|
|
|
|
// Static member functions for compile-time access
|
|
__host__ __device__ static constexpr auto get_rs_lengths() {
|
|
return RsLengths_{};
|
|
}
|
|
|
|
__host__ __device__ static constexpr auto get_hs_lengthss() {
|
|
return HsLengthss_{};
|
|
}
|
|
|
|
// Nested detail struct performs complex compile-time calculations
|
|
struct detail {
|
|
// Precomputed mappings and transformations
|
|
static constexpr auto get_h_dim_lengths_prefix_sum();
|
|
static constexpr auto get_uniformed_idx_y_to_h();
|
|
// ... compile-time computation ...
|
|
};
|
|
};
|
|
|
|
Key Template Features
|
|
---------------------
|
|
|
|
1. **Template Metaprogramming**: All parameters are types, not values, enabling compile-time optimization
|
|
2. **Constexpr Functions**: Everything is computed at compile time
|
|
3. **Type Aliases**: Clean access to template parameters
|
|
4. **Static Member Functions**: No runtime overhead
|
|
|
|
Parameter Breakdown
|
|
===================
|
|
|
|
R-Dimensions: Replication Specification
|
|
---------------------------------------
|
|
|
|
The ``RsLengths`` parameter defines dimensions that are replicated across processing units, enabling data sharing patterns essential for many tensor operations:
|
|
|
|
.. code-block:: cpp
|
|
|
|
// Example: GEMM with warp-level replication
|
|
using RsLengths = Sequence<NWarpPerBlock, MWarpPerBlock>;
|
|
|
|
// This creates replication pattern:
|
|
// - NWarpPerBlock warps share the same A data
|
|
// - MWarpPerBlock warps share the same B data
|
|
|
|
Replication serves several purposes:
|
|
|
|
- **Data Reuse**: Same input data needed by multiple output computations
|
|
- **Reduction Operations**: Multiple threads collaborate on single result
|
|
- **Memory Efficiency**: Reduces global memory bandwidth requirements
|
|
|
|
H-Dimensions: Hierarchical Decomposition
|
|
----------------------------------------
|
|
|
|
The ``HsLengthss`` parameter represents hierarchical decomposition of tensor dimensions:
|
|
|
|
.. code-block:: cpp
|
|
|
|
// Example: Block-level GEMM decomposition
|
|
using HsLengthss = Tuple<
|
|
Sequence<MRepeat, MWarp, MThread, MVec>, // M-dimension
|
|
Sequence<NRepeat, NWarp, NThread, NVec> // N-dimension
|
|
>;
|
|
|
|
// This creates hierarchy:
|
|
// - MRepeat: iterations per thread in M
|
|
// - MWarp: warps assigned to M
|
|
// - MThread: threads per warp for M
|
|
// - MVec: vector size for M
|
|
|
|
The decomposition enables:
|
|
|
|
- **Memory Coalescing**: Aligning with warp/thread organization
|
|
- **Register Blocking**: Tile sizes that fit in register file
|
|
- **Shared Memory Utilization**: Tiles that exploit data reuse
|
|
|
|
P-Dimensions: Partition Mapping
|
|
-------------------------------
|
|
|
|
The ``Ps2RHssMajor`` and ``Ps2RHssMinor`` parameters define work assignment:
|
|
|
|
.. code-block:: cpp
|
|
|
|
// Example: 2D thread block mapping
|
|
// P0 = warp_id, P1 = lane_id
|
|
using Ps2RHssMajor = Tuple<
|
|
Sequence<1>, // P0 maps to H1 (warp dimension)
|
|
Sequence<2> // P1 maps to H2 (thread dimension)
|
|
>;
|
|
using Ps2RHssMinor = Tuple<
|
|
Sequence<1>, // Use second component of H1
|
|
Sequence<2> // Use third component of H2
|
|
>;
|
|
|
|
The mapping mechanism:
|
|
|
|
- **Major Index**: Which RH-dimension group (0=R, 1-N=H)
|
|
- **Minor Index**: Component within that group
|
|
|
|
Y-Dimensions: Logical View Mapping
|
|
----------------------------------
|
|
|
|
The ``Ys2RHsMajor`` and ``Ys2RHsMinor`` define the user-facing interface:
|
|
|
|
.. code-block:: cpp
|
|
|
|
// Example: 2D tile access pattern
|
|
using Ys2RHsMajor = Sequence<1, 1, 2, 2>; // Y→H mapping
|
|
using Ys2RHsMinor = Sequence<0, 1, 0, 1>; // Component selection
|
|
|
|
// Creates 2x2 logical view:
|
|
// Y[0,0] → H1[0], H2[0]
|
|
// Y[0,1] → H1[1], H2[0]
|
|
// Y[1,0] → H1[0], H2[1]
|
|
// Y[1,1] → H1[1], H2[1]
|
|
|
|
Transformation Pipeline
|
|
=======================
|
|
|
|
The encoding generates a transformation pipeline that converts coordinates using the concepts from :ref:`ck_tile_transforms` and :ref:`ck_tile_adaptors`:
|
|
|
|
..
|
|
Original mermaid diagram (edit here, then run update_diagrams.py)
|
|
|
|
.. mermaid::
|
|
|
|
flowchart LR
|
|
subgraph "Input Coordinates"
|
|
P["P-coordinates<br/>[warp_id, lane_id]"]
|
|
Y["Y-coordinates<br/>[y0, y1, y2, y3]"]
|
|
end
|
|
|
|
subgraph "Transformation Pipeline"
|
|
C1["Combine P+Y"]
|
|
T1["Replicate<br/>Transform<br/>(if R-dims exist)"]
|
|
T2["Unmerge<br/>Transform<br/>(break into H-dims)"]
|
|
T3["Merge<br/>Transform<br/>(combine to X-dims)"]
|
|
end
|
|
|
|
subgraph "Output"
|
|
X["X-coordinates<br/>[x0, x1]<br/>Tensor position"]
|
|
end
|
|
|
|
P --> C1
|
|
Y --> C1
|
|
C1 --> T1
|
|
T1 --> T2
|
|
T2 --> T3
|
|
T3 --> X
|
|
|
|
style P fill:#e3f2fd,stroke:#1976d2,stroke-width:2px
|
|
style Y fill:#fff3e0,stroke:#f57c00,stroke-width:2px
|
|
style X fill:#e8f5e9,stroke:#388e3c,stroke-width:2px
|
|
|
|
|
|
|
|
.. image:: diagrams/encoding_internals_2.svg
|
|
:alt: Diagram
|
|
:align: center
|
|
|
|
Building the Transformation Chain
|
|
---------------------------------
|
|
|
|
.. code-block:: cpp
|
|
|
|
template <typename Encoding>
|
|
__host__ __device__ auto make_ps_ys_to_xs_adaptor(const Encoding& encoding)
|
|
{
|
|
// Step 1: Create individual transforms
|
|
constexpr auto replicate_transform = make_replicate_transform(
|
|
encoding.get_rs_lengths());
|
|
|
|
constexpr auto unmerge_transform = make_unmerge_transform(
|
|
encoding.get_hs_lengthss());
|
|
|
|
constexpr auto merge_transform = make_merge_transform(
|
|
encoding.get_rhs_to_xs_mapping());
|
|
|
|
// Step 2: Chain transforms together
|
|
constexpr auto transform_chain = chain_transforms(
|
|
replicate_transform,
|
|
unmerge_transform,
|
|
merge_transform);
|
|
|
|
// Step 3: Create adaptor with the chain
|
|
return make_tile_adaptor(
|
|
transform_chain,
|
|
encoding.get_lower_dimension_hidden_idss());
|
|
}
|
|
|
|
Transform Implementation Example
|
|
--------------------------------
|
|
|
|
.. code-block:: cpp
|
|
|
|
// Replicate transform implementation
|
|
template <typename Lengths>
|
|
struct replicate_transform
|
|
{
|
|
static constexpr index_t num_of_upper_dimension = size(Lengths{});
|
|
static constexpr index_t num_of_lower_dimension = 2 * num_of_upper_dimension;
|
|
|
|
template <typename UpperIndex>
|
|
__host__ __device__ constexpr auto
|
|
calculate_lower_index(const UpperIndex& idx_upper) const
|
|
{
|
|
// Replicate each coordinate: [a,b] -> [a,b,0,0]
|
|
auto idx_lower = make_zero_multi_index<num_of_lower_dimension>();
|
|
|
|
static_for<0, num_of_upper_dimension, 1>{}([&](auto i) {
|
|
idx_lower(i) = idx_upper[i];
|
|
idx_lower(i + num_of_upper_dimension) = 0;
|
|
});
|
|
|
|
return idx_lower;
|
|
}
|
|
};
|
|
|
|
Y to D Linearization
|
|
====================
|
|
|
|
The Y→D descriptor handles memory layout within each thread, building on :ref:`ck_tile_descriptors` concepts:
|
|
|
|
.. code-block:: cpp
|
|
|
|
template <typename YLengths, typename YStrides>
|
|
struct ys_to_d_descriptor
|
|
{
|
|
static constexpr index_t num_of_dimension = size(YLengths{});
|
|
|
|
// Calculate linear offset from Y coordinates
|
|
template <typename YIndex>
|
|
__host__ __device__ constexpr index_t
|
|
calculate_offset(const YIndex& idx_y) const
|
|
{
|
|
index_t offset = 0;
|
|
|
|
static_for<0, num_of_dimension, 1>{}([&](auto i) {
|
|
offset += idx_y[i] * YStrides{}[i];
|
|
});
|
|
|
|
return offset;
|
|
}
|
|
|
|
// Get element space size (total elements per thread)
|
|
__host__ __device__ static constexpr index_t
|
|
get_element_space_size()
|
|
{
|
|
return reduce_on_sequence(
|
|
YLengths{},
|
|
multiplies{},
|
|
number<1>{});
|
|
}
|
|
};
|
|
|
|
Memory Layout Optimization
|
|
--------------------------
|
|
|
|
.. code-block:: cpp
|
|
|
|
// Optimized layout for vector operations
|
|
template <index_t M, index_t N, index_t VectorSize>
|
|
struct make_ys_to_d_descriptor_for_gemm
|
|
{
|
|
// Layout: [M/VectorSize][N][VectorSize]
|
|
// This ensures vector loads are contiguous in memory
|
|
using type = tile_descriptor<
|
|
Sequence<M/VectorSize, N, VectorSize>,
|
|
Sequence<N * VectorSize, VectorSize, 1>>;
|
|
};
|
|
|
|
Integration in Distributed Tensor
|
|
---------------------------------
|
|
|
|
This shows how the encoding integrates with :ref:`ck_tile_static_distributed_tensor`:
|
|
|
|
.. code-block:: cpp
|
|
|
|
template <typename TileDistribution>
|
|
struct static_distributed_tensor
|
|
{
|
|
using ys_to_d_descriptor = typename TileDistribution::ys_to_d_descriptor;
|
|
|
|
// Thread-local storage
|
|
static constexpr index_t thread_buffer_size =
|
|
ys_to_d_descriptor::get_element_space_size();
|
|
|
|
DataType thread_buffer_[thread_buffer_size];
|
|
|
|
// Access element at Y coordinate
|
|
template <typename YIndex>
|
|
__host__ __device__ DataType& at(const YIndex& idx_y)
|
|
{
|
|
const index_t offset = ys_to_d_descriptor{}.calculate_offset(idx_y);
|
|
return thread_buffer_[offset];
|
|
}
|
|
};
|
|
|
|
Practical Examples
|
|
==================
|
|
|
|
Example 1: Simple 2x2 Distribution
|
|
----------------------------------
|
|
|
|
.. code-block:: cpp
|
|
|
|
// No replication, simple hierarchy
|
|
using SimpleEncoding = tile_distribution_encoding<
|
|
Sequence<>, // rs_lengths: no replication
|
|
Tuple< // hs_lengthss: 2x2 hierarchy
|
|
Sequence<2>,
|
|
Sequence<2>
|
|
>,
|
|
Tuple<Sequence<>, Sequence<>>, // ps_to_rhss_major
|
|
Tuple<Sequence<>, Sequence<>>, // ps_to_rhss_minor
|
|
Sequence<1, 2>, // ys_to_rhs_major
|
|
Sequence<0, 0> // ys_to_rhs_minor
|
|
>;
|
|
|
|
Example 2: GEMM Distribution
|
|
----------------------------
|
|
|
|
.. code-block:: cpp
|
|
|
|
// Complex GEMM distribution with replication
|
|
template<index_t MPerBlock, index_t NPerBlock, index_t KPerBlock,
|
|
index_t MPerWarp, index_t NPerWarp,
|
|
index_t MRepeat, index_t NRepeat>
|
|
using GemmBlockEncoding = tile_distribution_encoding<
|
|
Sequence<>, // No block-level replication
|
|
Tuple< // Hierarchical decomposition
|
|
Sequence<MRepeat, MPerBlock/MPerWarp/MRepeat>, // M
|
|
Sequence<NRepeat, NPerBlock/NPerWarp/NRepeat> // N
|
|
>,
|
|
Tuple< // Warp assignment
|
|
Sequence<1, 2>, // [warp_m, warp_n]
|
|
Sequence<>
|
|
>,
|
|
Tuple<
|
|
Sequence<1, 0>, // Major indices
|
|
Sequence<>
|
|
>,
|
|
Sequence<1, 1, 2, 2>, // Y mapping
|
|
Sequence<0, 1, 0, 1> // Y components
|
|
>;
|
|
|
|
Performance Implications
|
|
========================
|
|
|
|
The encoding system is designed for maximum GPU performance. See :ref:`ck_tile_gpu_basics` for hardware fundamentals.
|
|
|
|
Memory Access Patterns
|
|
----------------------
|
|
|
|
- **Coalescing**: Hierarchical decomposition ensures adjacent threads access adjacent memory
|
|
- **Bank Conflicts**: Careful dimension ordering prevents shared memory conflicts. See :ref:`ck_tile_lds_bank_conflicts` for more information.
|
|
- **Vectorization**: Natural support for vector loads and stores. See :ref:`ck_tile_load_store_traits` for more information.
|
|
|
|
Register Efficiency
|
|
-------------------
|
|
|
|
- **Optimal Allocation**: Y→D linearization minimizes register usage
|
|
- **Spill Avoidance**: Compile-time sizing prevents register spills
|
|
- **Reuse Patterns**: Encoding enables efficient register reuse
|
|
|
|
Compile-Time Optimization
|
|
-------------------------
|
|
|
|
.. code-block:: cpp
|
|
|
|
// All encoding operations resolve at compile time
|
|
template<typename Encoding>
|
|
struct encoding_optimizer {
|
|
// Compute all derived values at compile time
|
|
static constexpr auto total_elements = /* computed */;
|
|
static constexpr auto access_pattern = /* computed */;
|
|
static constexpr auto memory_layout = /* computed */;
|
|
|
|
// Generate optimized code paths
|
|
template<typename Func>
|
|
__device__ void apply_optimized(Func&& f) {
|
|
if constexpr (is_simple_pattern) {
|
|
// Direct access path
|
|
} else if constexpr (is_strided_pattern) {
|
|
// Strided access path
|
|
} else {
|
|
// General access path
|
|
}
|
|
}
|
|
};
|
|
|
|
Summary
|
|
=======
|
|
|
|
The tile distribution encoding system demonstrates compile-time computation:
|
|
|
|
- **Mathematical Foundation**: Complete specification through dimensional relationships
|
|
- **Zero Overhead**: All computations resolve at compile time
|
|
- **Composable Design**: Individual transforms compose into complex mappings
|
|
- **Hardware Alignment**: Natural mapping to GPU execution hierarchy
|
|
- **Performance Focus**: Every design decision optimizes for GPU efficiency
|
|
|
|
The encoding internals show how CK Tile achieves practical performance. By leveraging C++ template metaprogramming and careful architectural design, the framework generates code that rivals hand-optimized implementations while maintaining clarity and composability.
|
|
|
|
For practical examples of how the encoding system is used, see :ref:`ck_tile_thread_mapping`. For coordinate operations that build on these encodings, see :ref:`ck_tile_coordinate_movement`.
|