mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-15 18:42:06 +00:00
[CK-TILE] File-level documentation for static encoding pattern (#2433)
* add file-level comment
* Finished the write-up
---------
Co-authored-by: ThomasNing <thomas.ning@amd.com>
[ROCm/composable_kernel commit: 158ddeb8ce]
This commit is contained in:
@@ -1,6 +1,73 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
/**
|
||||
* @file
|
||||
* We're defining the data access pattern for a 2D window (`XPerTile` by `YPerTile`)
|
||||
for `BlockSize` threads in a thread block.
|
||||
* X dimension is considered contiguous in memory, so a single instruction can access
|
||||
several adjacent and properly aligned elements (vector); the access pattern along X tile
|
||||
dimension is parameterized only by the suggested vector size `VecSize`.
|
||||
* We can't access more than `MaxVecSize = TileElementsPerThread = TileSize / BlockSize` elements
|
||||
with a single memory access, so the actual vector size along the X dimension is
|
||||
`X0 = min(MaxVecSize, VecSize)`.
|
||||
* This leaves `X1 = XPerTile / X0` threads per tile in X dimension.
|
||||
* X1 is also the number of threads per warp in X dimension, that is,
|
||||
X dimension is not split between warps, and each warp accesses X dimension entirely,
|
||||
and there is no iteration in X dimension.
|
||||
* The tuple <X0, X1> defines the X-axis access pattern.
|
||||
This part is common between the 2D distribution patterns.
|
||||
|
||||
* What differs between the 2D distribution patterns is the Y-axis access pattern.
|
||||
* There are 3 components in this access pattern:
|
||||
* (1) number of Y-axis elements (rows) per warp for a single instruction access,
|
||||
* (2) number of warps per thread block,
|
||||
* (3) number of iterations to cover the entire Y axis.
|
||||
|
||||
* The term 'raked' here describes how data is partitioned across the different processing granularities.
|
||||
* It represents how we are going to access the data: per thread, per warp, or per block, in a contiguous
|
||||
region.
|
||||
* From below, the qualifier for 'raked' is the part of warp/thread hierarchy
|
||||
* in the split of Y tile dimension where the iteration happens,
|
||||
* meaning, the iteration can be logically inserted as a tile dimension in 3 ways,
|
||||
* (1) after thread -> thread-raked,
|
||||
* (2) between warp and thread -> warp-raked,
|
||||
* (3) before warp -> block-raked
|
||||
|
||||
* *Thread raked*
|
||||
|
||||
* Y0 is the number of warps, which we can get from the equation `Y0 * WarpSize == BlockSize`
|
||||
* Y1 is the number of rows accessed by a warp within a single iteration,
|
||||
compute it from the equation `Y1 * X1 == WarpSize`
|
||||
* Y2 is the number of iterations to cover the tile,
|
||||
compute it from the equation `Y0 * Y1 * Y2 == YPerTile`
|
||||
|
||||
* *Warp raked*
|
||||
|
||||
* Y0 is the number of warps, we can get it in the same way as for thread-raked pattern,
|
||||
`Y0 * WarpSize == BlockSize`
|
||||
* Y1 is the number of iterations to cover the tile, `Y0 * Y1 * Y2 == YPerTile`.
|
||||
Compute Y2 from the equation below
|
||||
* Y2 is the number of rows accessed by a warp in a single iteration, `Y2 * X1 == WarpSize`
|
||||
|
||||
* *Block raked*
|
||||
|
||||
* Y0 is the number of iterations to cover the tile, `Y0 * Y1 * Y2 == YPerTile`.
|
||||
Compute Y1 and Y2 from the equations below
|
||||
* Y1 is the number of warps, `Y1 * WarpSize == BlockSize`
|
||||
* Y2 is the number of rows accessed by a warp in a single iteration, `Y2 * X1 == WarpSize`
|
||||
|
||||
* In all cases, the tuple <Y0, Y1, Y2> defines the Y-axis access pattern.
|
||||
|
||||
* *Selection*
|
||||
* When selecting a pattern, thread-raked is used for element-wise operations because it is the
|
||||
* Thread-major memory order.
|
||||
* Warp-raked is used in matrix multiplication because the vectorization is at the warp level.
|
||||
* Block-raked is used mostly for the reduction process, where we reduce the block at the global
|
||||
* atomic level.
|
||||
*
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core/arch/arch.hpp"
|
||||
@@ -105,9 +172,9 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
tile_distribution_encoding<sequence<Y0>,
|
||||
tuple<sequence<Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<0>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>, // -> <Y0>, <Y1, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
sequence<1, 1>>{}); // -> <Y2, X1>
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -115,9 +182,9 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>, // -> <Y0>, <Y1, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<2, 1>>{});
|
||||
sequence<2, 1>>{}); // -> <Y2, X1>
|
||||
}
|
||||
}
|
||||
|
||||
@@ -129,9 +196,9 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
tile_distribution_encoding<sequence<Y0>,
|
||||
tuple<sequence<X0, X1>, sequence<Y1, Y2>>,
|
||||
tuple<sequence<0>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>,
|
||||
tuple<sequence<0>, sequence<0, 0>>, // -> <Y0>, <Y1, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
sequence<1, 1>>{}); // -> <X1, Y2>
|
||||
}
|
||||
else
|
||||
{
|
||||
@@ -139,9 +206,9 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>,
|
||||
tuple<sequence<0>, sequence<1, 0>>, // -> <Y0>, <Y1, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<1, 2>>{});
|
||||
sequence<1, 2>>{}); // -> <X1, Y2>
|
||||
}
|
||||
}
|
||||
};
|
||||
@@ -182,9 +249,9 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>, // -> <Y0>, <Y2, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
sequence<1, 1>>{}); // -> <Y1, X1>
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
|
||||
@@ -193,9 +260,9 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>,
|
||||
tuple<sequence<0>, sequence<2, 0>>, // -> <Y0>, <Y2, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<1, 1>>{});
|
||||
sequence<1, 1>>{}); // -> <X1, Y1>
|
||||
}
|
||||
};
|
||||
|
||||
@@ -233,9 +300,9 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<Y0, Y1, Y2>, sequence<X0, X1>>,
|
||||
tuple<sequence<1>, sequence<1, 2>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>, // -> <Y1>, <Y2, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<0, 1>>{});
|
||||
sequence<0, 1>>{}); // -> <Y0, X1>
|
||||
}
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffled2DStaticTileDistribution()
|
||||
@@ -244,9 +311,9 @@ struct TileDistributionEncodingPattern2D<BlockSize,
|
||||
tile_distribution_encoding<sequence<1>,
|
||||
tuple<sequence<X0, X1>, sequence<Y0, Y1, Y2>>,
|
||||
tuple<sequence<2>, sequence<2, 1>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>,
|
||||
tuple<sequence<1>, sequence<2, 0>>, // -> <Y1>, <Y2, X0>
|
||||
sequence<1, 2>,
|
||||
sequence<1, 0>>{});
|
||||
sequence<1, 0>>{}); // -> <X1, Y0>
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
Reference in New Issue
Block a user