mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-10 16:28:38 +00:00
[rocm-libraries] ROCm/rocm-libraries#7829 (commit 13af7da)
[ck] Enforce ASCII-only C/C++ sources for hipRTC compatibility (#7829) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## Summary CK source files must be compilable via **hipRTC (HIP runtime compilation)**, whose preprocessor does not accept non-ASCII bytes anywhere in a translation unit — **including in comments**. Bytes that are harmless under `hipcc` (em-dashes, smart quotes, multiplication signs, Greek letters, box-drawing glyphs, etc.) cause hipRTC to fail at preprocessing time. These regularly leak in via LLM-assisted authoring or copy/paste from formatted documents and silently break hipRTC paths that are not exercised by the default `hipcc`-based build matrix. This PR (a) cleans every existing violation (53 files) and (b) adds a pre-checkin gate so new violations are rejected before merge. ## File extensions covered Both the cleanup scan and the new Jenkins enforcement stage use the same predicate: ``` *.h *.hpp *.cpp *.h.in *.hpp.in *.cpp.in *.inc *.cl ``` (excluding `*/build/*` and `*/include/rapidjson/*`). This is a strict superset of the existing `Clang Format` stage's predicate — `*.inc` is added so test-fixture include files are also gated. The local pre-commit hook's `c++/inc` type filter covers the same set. ## Why no enforcement today CK is opted out of the rocm-libraries root `.pre-commit-config.yaml`, so the existing `pre-commit` workflow doesn't touch CK. The local CK `.pre-commit-config.yaml` only runs for developers who installed hooks. The **authoritative gate is therefore the new Jenkins stage** in this PR; the local hook is convenience. ## Commit layout (bisect-friendly) 1. `79798aa6261` — **`[ck] Convert reflect/ rendering to ASCII for hipRTC compatibility`** Behavior change, isolated. `TreeFormatter` swaps `├─ / └─ / │ ` for `|- / +- / | ` (3-col width preserved so alignment is unchanged). `conv_description.hpp` swaps `×` for `x` as the dimension separator. `test_conv_description.cpp` expected strings updated in lockstep so the snapshot test stays green. This is the only commit in the series with observable runtime impact. 2. `738fdb0d81c` — **`[ck] Strip non-ASCII bytes from C++ sources for hipRTC compatibility`** Mechanical text cleanup across 53 files. Replacements happen in comments or in `std::cout` strings that are not asserted on by any test. None of the 174 `.inc` files in the tree required edits, but they were in the scan's predicate so the enforcement stage's predicate is a superset of what was scanned. Full replacement table in the commit message. 3. `1d7cd8ba235` — **`[ck] Enforce ASCII-only C/C++ sources for hipRTC compatibility`** - New `projects/composablekernel/script/check_ascii_only.sh` (modeled on `check_copyright_year.sh`). - New entry in `projects/composablekernel/.pre-commit-config.yaml` under the local-hooks block (`types_or: [c++, inc]`). - New `ASCII Only Check` parallel stage in `projects/composablekernel/Jenkinsfile`'s `Static checks` block, mirroring the existing `Clang Format` stage but with `*.inc` added to the find predicate. Always-on, no `RUN_CPPCHECK` gate. The tree is buildable at every commit boundary. Commit 1 leaves 50 known violations; commit 2 leaves 0; commit 3 wires the gate. ## Demo Script output on a synthesized violation: ``` $ printf '// em-dash test \xe2\x80\x94 here\n' > /tmp/bad.cpp $ projects/composablekernel/script/check_ascii_only.sh /tmp/bad.cpp ERROR: /tmp/bad.cpp contains non-ASCII bytes: 1:// em-dash test — here Fix: replace with ASCII (em-dash -> --, smart quotes -> ", arrows -> ->, etc.) $ echo $? 1 ``` Full repo scan after the cleanup commits (note the `-name '*.inc'` clause): ``` $ cd projects/composablekernel && find . -type f \( -name '*.h' -o -name '*.hpp' -o -name '*.cpp' \ -o -name '*.h.in' -o -name '*.hpp.in' -o -name '*.cpp.in' -o -name '*.inc' -o -name '*.cl' \) \ -not -path '*/build/*' -not -path '*/include/rapidjson/*' -print0 \ | xargs -0 -P 8 -n 64 script/check_ascii_only.sh $ echo $? 0 ``` ## Test plan - [ ] Jenkins PR build: confirm new `Static checks -> ASCII Only Check` stage runs green over the full predicate (incl. `*.inc`) and existing `Clang Format` stage is unaffected. - [ ] `test_conv_description` passes against the ASCII tree-formatter output (touched in commit 1). - [ ] Local: `pre-commit run ascii-only-checker --all-files` runs cleanly after installing CK pre-commit hooks via `script/install_precommit.sh`. - [ ] Manually inject a non-ASCII byte in any `.cpp/.hpp/.inc` file, push: confirm Jenkins fails the new stage with a clear error. - [ ] Spot-check a representative subset of touched files under hipRTC compilation to confirm no remaining hipRTC-blocking content (optional, since the static byte check is a sufficient condition for hipRTC preprocessor acceptance on this dimension). 🤖 Generated with [Claude Code](https://claude.com/claude-code)
This commit is contained in:
committed by
assistant-librarian[bot]
parent
4fcd73a98e
commit
96c39b331e
@@ -2,7 +2,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/*
|
||||
* Tutorial: CK Tile Distribution Encoding — A Matrix DRAM Load
|
||||
* Tutorial: CK Tile Distribution Encoding -- A Matrix DRAM Load
|
||||
*
|
||||
* Demonstrates how tile_distribution_encoding maps threads to A-matrix
|
||||
* elements during a DRAM load in the naive GEMM tutorial.
|
||||
@@ -10,7 +10,7 @@
|
||||
* Source: block_gemm_pipeline_agmem_bgmem_creg_policy.hpp
|
||||
* MakeADramTileDistribution(), with fp16, BlockSize=256
|
||||
*
|
||||
* Tile: M=256 × K=32 (matches the naive GEMM's A block tile)
|
||||
* Tile: M=256 x K=32 (matches the naive GEMM's A block tile)
|
||||
* Threads: 256 (4 warps on CDNA, 8 on RDNA)
|
||||
*
|
||||
* Host initialises A with sequential values 0, 1, 2, ... (row-major).
|
||||
@@ -22,7 +22,7 @@
|
||||
* The distribution encoding is hardcoded to match the fp16 derivation
|
||||
* (K1=16/sizeof(fp16)=8), not recomputed from sizeof(int32_t).
|
||||
*
|
||||
* No compute is performed — this is purely about data movement.
|
||||
* No compute is performed -- this is purely about data movement.
|
||||
*
|
||||
* Note: Comments and values assume CDNA (warp_size=64). On RDNA (warp_size=32),
|
||||
* the thread-to-data mapping will differ.
|
||||
@@ -37,89 +37,89 @@ using namespace ck_tile;
|
||||
// ============================================================================
|
||||
// THE GOAL
|
||||
// ============================================================================
|
||||
// Matrix A: M=256 rows × K=32 columns, stored in DRAM (row-major, fp16).
|
||||
// Matrix A: M=256 rows x K=32 columns, stored in DRAM (row-major, fp16).
|
||||
// Load the entire tile into registers using 256 threads (4 warps on CDNA).
|
||||
//
|
||||
// For coalesced memory access with fp16, each lane loads 8 contiguous
|
||||
// K-values (8 × 2 bytes = 16 bytes = 128 bits). Since K=32, we need
|
||||
// K-values (8 x 2 bytes = 16 bytes = 128 bits). Since K=32, we need
|
||||
// 32/8 = 4 lanes to cover one row:
|
||||
//
|
||||
// lane 0: K=0..7 lane 1: K=8..15 lane 2: K=16..23 lane 3: K=24..31
|
||||
// └──────────────── one row of 32 K-columns ──────────────────────────────┘
|
||||
// +---------------- one row of 32 K-columns ------------------------------+
|
||||
//
|
||||
// With warp_size=64, each warp has 64 lanes. 4 lanes per row means
|
||||
// 64/4 = 16 rows per warp. With 4 warps, one pass covers 4×16 = 64 rows.
|
||||
// 64/4 = 16 rows per warp. With 4 warps, one pass covers 4x16 = 64 rows.
|
||||
// To cover all 256 rows, each thread iterates M0 = 256/64 = 4 times.
|
||||
//
|
||||
// Per-thread buffer = 4 iterations × 8 K-values = 32 elements.
|
||||
// Per-thread buffer = 4 iterations x 8 K-values = 32 elements.
|
||||
//
|
||||
// Visually for warp 0 (lanes 0–63):
|
||||
// Visually for warp 0 (lanes 0-63):
|
||||
//
|
||||
// A matrix (256×32) lane_id decomposition
|
||||
// ──────────────── ──────────────────────
|
||||
// A matrix (256x32) lane_id decomposition
|
||||
// ---------------- ----------------------
|
||||
// row 0: [ K=0..7 | 8..15 | 16..23 | 24..31 ]
|
||||
// L0 L1 L2 L3 ← iter 0
|
||||
// L0 L1 L2 L3 <- iter 0
|
||||
// row 1: [ K=0..7 | 8..15 | 16..23 | 24..31 ]
|
||||
// L4 L5 L6 L7
|
||||
// ...
|
||||
// row 15: same pattern, lanes 60–63
|
||||
// ────── stride of 64 rows (4 warps × 16 rows/warp) ──────
|
||||
// row 64: L0..L3 ← iter 1
|
||||
// row 15: same pattern, lanes 60-63
|
||||
// ------ stride of 64 rows (4 warps x 16 rows/warp) ------
|
||||
// row 64: L0..L3 <- iter 1
|
||||
// ...
|
||||
// row 128: L0..L3 ← iter 2
|
||||
// row 128: L0..L3 <- iter 2
|
||||
// ...
|
||||
// row 192: L0..L3 ← iter 3
|
||||
// row 192: L0..L3 <- iter 3
|
||||
//
|
||||
// ============================================================================
|
||||
// THE SOLUTION: tile_distribution_encoding
|
||||
// ============================================================================
|
||||
//
|
||||
// Production code derives (fp16, BlockSize=256, MPerBlock=256, KPerBlock=32):
|
||||
// K1 = 16/sizeof(fp16) = 8 → vector load width (8 values)
|
||||
// K0 = KPerBlock/K1 = 4 → 4 K-chunks per row
|
||||
// M2 = warp_size/K0 = 16 → 16 rows per warp
|
||||
// M1 = BlockSize/warp_size = 4 → 4 warps
|
||||
// M0 = MPerBlock/(M2*M1) = 4 → 4 iterations
|
||||
// K1 = 16/sizeof(fp16) = 8 -> vector load width (8 values)
|
||||
// K0 = KPerBlock/K1 = 4 -> 4 K-chunks per row
|
||||
// M2 = warp_size/K0 = 16 -> 16 rows per warp
|
||||
// M1 = BlockSize/warp_size = 4 -> 4 warps
|
||||
// M0 = MPerBlock/(M2*M1) = 4 -> 4 iterations
|
||||
//
|
||||
// Step 1 — Hierarchical dimensions (Hs): factor each axis.
|
||||
// Step 1 -- Hierarchical dimensions (Hs): factor each axis.
|
||||
//
|
||||
// Hs[0] = sequence<4, 4, 16> → M = 4 × 4 × 16 = 256
|
||||
// Hs[1] = sequence<4, 8> → K = 4 × 8 = 32
|
||||
// Hs[0] = sequence<4, 4, 16> -> M = 4 x 4 x 16 = 256
|
||||
// Hs[1] = sequence<4, 8> -> K = 4 x 8 = 32
|
||||
//
|
||||
// Hs[0] Hs[1]
|
||||
// ┌─────┼─────┐ ┌───┴───┐
|
||||
// +-----+-----+ +---+---+
|
||||
// level 0 level 1 level 2 level 0 level 1
|
||||
// = 4 = 4 = 16 = 4 = 8
|
||||
//
|
||||
// Step 2 — Parallel dimensions (Ps): NDimP=2 (P0=warp_id, P1=lane_id).
|
||||
// Step 2 -- Parallel dimensions (Ps): NDimP=2 (P0=warp_id, P1=lane_id).
|
||||
//
|
||||
// P0 = warp_id → Hs[0][1] = 4 (which warp → which M-group)
|
||||
// P1 = lane_id → Hs[0][2]=16 AND Hs[1][0]=4 (merged, total=64)
|
||||
// P0 = warp_id -> Hs[0][1] = 4 (which warp -> which M-group)
|
||||
// P1 = lane_id -> Hs[0][2]=16 AND Hs[1][0]=4 (merged, total=64)
|
||||
//
|
||||
// The merge transform decomposes lane_id:
|
||||
// row_in_warp = lane_id / 4 (0..15, outer)
|
||||
// k_chunk = lane_id % 4 (0..3, inner → coalesced!)
|
||||
// k_chunk = lane_id % 4 (0..3, inner -> coalesced!)
|
||||
//
|
||||
// Ps_major = tuple<sequence<1>, sequence<1, 2>>
|
||||
// Ps_minor = tuple<sequence<1>, sequence<2, 0>>
|
||||
//
|
||||
// How to read Ps: the tuple has 2 elements → NDimP=2.
|
||||
// How to read Ps: the tuple has 2 elements -> NDimP=2.
|
||||
// First element = P0 = warp_id
|
||||
// Second element = P1 = lane_id
|
||||
//
|
||||
// Ps_major = tuple< seq<1>, seq<1, 2> >
|
||||
// ─P0(warp)─ ─P1(lane)──
|
||||
// -P0(warp)- -P1(lane)--
|
||||
// Ps_minor = tuple< seq<1>, seq<2, 0> >
|
||||
// ─P0(warp)─ ─P1(lane)──
|
||||
// -P0(warp)- -P1(lane)--
|
||||
//
|
||||
// P0: major=<1>, minor=<1> → Hs[0], level 1 → M1=4
|
||||
// P1: major=<1,2>, minor=<2,0> → merged:
|
||||
// Hs[0] level 2 → M2=16 (outer, changes slowly)
|
||||
// Hs[1] level 0 → K0=4 (inner, changes every lane → coalesced!)
|
||||
// Total: 16 × 4 = 64 = warp_size
|
||||
// lane / 4 → row_in_warp (M2), lane % 4 → K-chunk (K0)
|
||||
// P0: major=<1>, minor=<1> -> Hs[0], level 1 -> M1=4
|
||||
// P1: major=<1,2>, minor=<2,0> -> merged:
|
||||
// Hs[0] level 2 -> M2=16 (outer, changes slowly)
|
||||
// Hs[1] level 0 -> K0=4 (inner, changes every lane -> coalesced!)
|
||||
// Total: 16 x 4 = 64 = warp_size
|
||||
// lane / 4 -> row_in_warp (M2), lane % 4 -> K-chunk (K0)
|
||||
//
|
||||
// Step 3 — Yield dimensions (Ys): what each thread owns.
|
||||
// Step 3 -- Yield dimensions (Ys): what each thread owns.
|
||||
//
|
||||
// Y0 = Hs[0][0] = 4 (M-iterations)
|
||||
// Y1 = Hs[1][1] = 8 (vector load width)
|
||||
@@ -127,27 +127,27 @@ using namespace ck_tile;
|
||||
// Ys_major = sequence<1, 2>
|
||||
// Ys_minor = sequence<0, 1>
|
||||
//
|
||||
// How to read Ys: parallel arrays — position i gives Yi.
|
||||
// How to read Ys: parallel arrays -- position i gives Yi.
|
||||
//
|
||||
// Ys_major = seq< 1, 2 > → Y0 is in Hs[0], Y1 is in Hs[1]
|
||||
// Ys_minor = seq< 0, 1 > → Y0 is level 0, Y1 is level 1
|
||||
// ─Y0─ ─Y1─
|
||||
// Ys_major = seq< 1, 2 > -> Y0 is in Hs[0], Y1 is in Hs[1]
|
||||
// Ys_minor = seq< 0, 1 > -> Y0 is level 0, Y1 is level 1
|
||||
// -Y0- -Y1-
|
||||
//
|
||||
// Y0: Hs[0] level 0 → M0=4 (iterations along M)
|
||||
// Y1: Hs[1] level 1 → K1=8 (vector load width)
|
||||
// Buffer size = Y0 × Y1 = 4 × 8 = 32 elements per thread.
|
||||
// Y0: Hs[0] level 0 -> M0=4 (iterations along M)
|
||||
// Y1: Hs[1] level 1 -> K1=8 (vector load width)
|
||||
// Buffer size = Y0 x Y1 = 4 x 8 = 32 elements per thread.
|
||||
//
|
||||
// Step 4 — Replicate: Rs = sequence<1> (trivial, size 1).
|
||||
// Step 4 -- Replicate: Rs = sequence<1> (trivial, size 1).
|
||||
//
|
||||
// Complete tree:
|
||||
//
|
||||
// Hs[0] Hs[1]
|
||||
// ┌─────┼─────┐ ┌───┴───┐
|
||||
// +-----+-----+ +---+---+
|
||||
// [Y0] [P0] [P1] [P1] [Y1]
|
||||
// = 4 = 4 = 16 = 4 = 8
|
||||
// (iter) (warp) (row) (K-chunk) (vec load)
|
||||
//
|
||||
// Buffer size = Y0 × Y1 = 4 × 8 = 32 elements per thread.
|
||||
// Buffer size = Y0 x Y1 = 4 x 8 = 32 elements per thread.
|
||||
//
|
||||
// ============================================================================
|
||||
|
||||
@@ -182,7 +182,7 @@ struct TileDistKernelA
|
||||
|
||||
const auto& buf = tile.get_thread_buffer();
|
||||
constexpr index_t warp_size = get_warp_size();
|
||||
constexpr index_t kBufSize = 32; // 4 iterations × 8 K-values
|
||||
constexpr index_t kBufSize = 32; // 4 iterations x 8 K-values
|
||||
|
||||
int32_t local_buf[kBufSize];
|
||||
static_for<0, kBufSize, 1>{}([&](auto i) { local_buf[i] = static_cast<int32_t>(buf[i]); });
|
||||
@@ -232,13 +232,13 @@ struct TileDistKernelA
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Lane 0: row_in_warp=0, k_chunk=0 → rows {0, 64, 128, 192}, K=0..7
|
||||
// Lane 0: row_in_warp=0, k_chunk=0 -> rows {0, 64, 128, 192}, K=0..7
|
||||
print_thread(0);
|
||||
__syncthreads();
|
||||
// Lane 1: k_chunk=1 → same rows, K=8..15 (coalesced with lane 0)
|
||||
// Lane 1: k_chunk=1 -> same rows, K=8..15 (coalesced with lane 0)
|
||||
print_thread(1);
|
||||
__syncthreads();
|
||||
// Lane 4: row_in_warp=1 → rows {1, 65, 129, 193}, K=0..7
|
||||
// Lane 4: row_in_warp=1 -> rows {1, 65, 129, 193}, K=0..7
|
||||
print_thread(4);
|
||||
__syncthreads();
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/*
|
||||
* Tutorial: CK Tile Distribution Encoding — B Matrix DRAM Load
|
||||
* Tutorial: CK Tile Distribution Encoding -- B Matrix DRAM Load
|
||||
*
|
||||
* Demonstrates how tile_distribution_encoding maps threads to B-matrix
|
||||
* elements during a DRAM load in the naive GEMM tutorial.
|
||||
@@ -10,7 +10,7 @@
|
||||
* Source: block_gemm_pipeline_agmem_bgmem_creg_policy.hpp
|
||||
* MakeBDramTileDistribution(), with fp16, BlockSize=256
|
||||
*
|
||||
* Tile: N=128 × K=32 (matches the naive GEMM's B block tile)
|
||||
* Tile: N=128 x K=32 (matches the naive GEMM's B block tile)
|
||||
* Threads: 256 (4 warps on CDNA, 8 on RDNA)
|
||||
*
|
||||
* The B encoding has the SAME structure as the A encoding (Tutorial 1),
|
||||
@@ -18,7 +18,7 @@
|
||||
* count), showing how the same encoding pattern adapts to different
|
||||
* tile sizes.
|
||||
*
|
||||
* No compute is performed — this is purely about data movement.
|
||||
* No compute is performed -- this is purely about data movement.
|
||||
*
|
||||
* Note: int32_t is used instead of fp16 for readable printf output.
|
||||
* The distribution encoding is hardcoded to match the fp16 derivation.
|
||||
@@ -36,21 +36,21 @@ using namespace ck_tile;
|
||||
// ============================================================================
|
||||
// THE GOAL
|
||||
// ============================================================================
|
||||
// Matrix B: N=128 rows × K=32 columns, stored in DRAM (row-major, fp16).
|
||||
// (In GEMM, B is stored as [N, K] — each "row" is one output channel.)
|
||||
// Matrix B: N=128 rows x K=32 columns, stored in DRAM (row-major, fp16).
|
||||
// (In GEMM, B is stored as [N, K] -- each "row" is one output channel.)
|
||||
// Load the entire tile into registers using 256 threads (4 warps on CDNA).
|
||||
//
|
||||
// Same coalescing strategy as the A-matrix (Tutorial 1):
|
||||
// - 4 lanes cover one K-row (4 × 8 = 32 K-values)
|
||||
// - 4 lanes cover one K-row (4 x 8 = 32 K-values)
|
||||
// - Each warp (64 lanes) covers 16 N-rows
|
||||
// - 4 warps cover 64 N-rows per iteration
|
||||
// - N0 = 128/64 = 2 iterations (vs 4 for A's M=256)
|
||||
//
|
||||
// Per-thread buffer = 2 iterations × 8 K-values = 16 elements.
|
||||
// Per-thread buffer = 2 iterations x 8 K-values = 16 elements.
|
||||
//
|
||||
// Compare with Tutorial 1 (A-matrix):
|
||||
// A: M=256, M0=4, buffer=32 | B: N=128, N0=2, buffer=16
|
||||
// Everything else is identical — same K-splitting, same coalescing.
|
||||
// Everything else is identical -- same K-splitting, same coalescing.
|
||||
//
|
||||
// ============================================================================
|
||||
// THE SOLUTION: tile_distribution_encoding
|
||||
@@ -63,47 +63,47 @@ using namespace ck_tile;
|
||||
// N1 = BlockSize/warp_size = 4
|
||||
// N0 = NPerBlock/(N2*N1) = 2
|
||||
//
|
||||
// Step 1 — Hierarchical dimensions (Hs):
|
||||
// Step 1 -- Hierarchical dimensions (Hs):
|
||||
//
|
||||
// Hs[0] = sequence<2, 4, 16> → N = 2 × 4 × 16 = 128
|
||||
// Hs[1] = sequence<4, 8> → K = 4 × 8 = 32
|
||||
// Hs[0] = sequence<2, 4, 16> -> N = 2 x 4 x 16 = 128
|
||||
// Hs[1] = sequence<4, 8> -> K = 4 x 8 = 32
|
||||
//
|
||||
// Hs[0] Hs[1]
|
||||
// ┌─────┼─────┐ ┌───┴───┐
|
||||
// +-----+-----+ +---+---+
|
||||
// [Y0] [P0] [P1] [P1] [Y1]
|
||||
// = 2 = 4 = 16 = 4 = 8
|
||||
// (iter) (warp) (row) (K-chunk) (vec load)
|
||||
//
|
||||
// Step 2 — Parallel dimensions (Ps): NDimP=2 (P0=warp_id, P1=lane_id).
|
||||
// Step 2 -- Parallel dimensions (Ps): NDimP=2 (P0=warp_id, P1=lane_id).
|
||||
//
|
||||
// Ps_major = tuple<sequence<1>, sequence<1, 2>>
|
||||
// Ps_minor = tuple<sequence<1>, sequence<2, 0>>
|
||||
//
|
||||
// How to read Ps: the tuple has 2 elements → NDimP=2.
|
||||
// How to read Ps: the tuple has 2 elements -> NDimP=2.
|
||||
// First element = P0 = warp_id
|
||||
// Second element = P1 = lane_id
|
||||
//
|
||||
// P0: major=<1>, minor=<1> → Hs[0], level 1 → N1=4 (which warp)
|
||||
// P1: major=<1,2>, minor=<2,0> → merged:
|
||||
// Hs[0] level 2 → N2=16 (outer, row within warp)
|
||||
// Hs[1] level 0 → K0=4 (inner, K-chunk → coalesced!)
|
||||
// lane / 4 → row_in_warp, lane % 4 → K-chunk
|
||||
// P0: major=<1>, minor=<1> -> Hs[0], level 1 -> N1=4 (which warp)
|
||||
// P1: major=<1,2>, minor=<2,0> -> merged:
|
||||
// Hs[0] level 2 -> N2=16 (outer, row within warp)
|
||||
// Hs[1] level 0 -> K0=4 (inner, K-chunk -> coalesced!)
|
||||
// lane / 4 -> row_in_warp, lane % 4 -> K-chunk
|
||||
//
|
||||
// Step 3 — Yield dimensions (Ys): what each thread owns.
|
||||
// Step 3 -- Yield dimensions (Ys): what each thread owns.
|
||||
//
|
||||
// Ys_major = sequence<1, 2>
|
||||
// Ys_minor = sequence<0, 1>
|
||||
//
|
||||
// How to read Ys: parallel arrays — position i gives Yi.
|
||||
// How to read Ys: parallel arrays -- position i gives Yi.
|
||||
//
|
||||
// Ys_major = seq< 1, 2 > → Y0 is in Hs[0], Y1 is in Hs[1]
|
||||
// Ys_minor = seq< 0, 1 > → Y0 is level 0, Y1 is level 1
|
||||
// ─Y0─ ─Y1─
|
||||
// Ys_major = seq< 1, 2 > -> Y0 is in Hs[0], Y1 is in Hs[1]
|
||||
// Ys_minor = seq< 0, 1 > -> Y0 is level 0, Y1 is level 1
|
||||
// -Y0- -Y1-
|
||||
//
|
||||
// Y0: Hs[0] level 0 → N0=2 (iterations along N)
|
||||
// Y1: Hs[1] level 1 → K1=8 (vector load width)
|
||||
// Y0: Hs[0] level 0 -> N0=2 (iterations along N)
|
||||
// Y1: Hs[1] level 1 -> K1=8 (vector load width)
|
||||
//
|
||||
// Buffer size = Y0 × Y1 = 2 × 8 = 16 elements per thread.
|
||||
// Buffer size = Y0 x Y1 = 2 x 8 = 16 elements per thread.
|
||||
//
|
||||
// ============================================================================
|
||||
|
||||
@@ -138,7 +138,7 @@ struct TileDistKernelB
|
||||
|
||||
const auto& buf = tile.get_thread_buffer();
|
||||
constexpr index_t warp_size = get_warp_size();
|
||||
constexpr index_t kBufSize = 16; // 2 iterations × 8 K-values
|
||||
constexpr index_t kBufSize = 16; // 2 iterations x 8 K-values
|
||||
|
||||
int32_t local_buf[kBufSize];
|
||||
static_for<0, kBufSize, 1>{}([&](auto i) { local_buf[i] = static_cast<int32_t>(buf[i]); });
|
||||
@@ -187,13 +187,13 @@ struct TileDistKernelB
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Lane 0: row_in_warp=0, k_chunk=0 → rows {0, 64}, K=0..7
|
||||
// Lane 0: row_in_warp=0, k_chunk=0 -> rows {0, 64}, K=0..7
|
||||
print_thread(0);
|
||||
__syncthreads();
|
||||
// Lane 1: k_chunk=1 → same rows, K=8..15
|
||||
// Lane 1: k_chunk=1 -> same rows, K=8..15
|
||||
print_thread(1);
|
||||
__syncthreads();
|
||||
// Lane 4: row_in_warp=1 → rows {1, 65}, K=0..7
|
||||
// Lane 4: row_in_warp=1 -> rows {1, 65}, K=0..7
|
||||
print_thread(4);
|
||||
__syncthreads();
|
||||
|
||||
|
||||
@@ -2,12 +2,12 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
/*
|
||||
* Tutorial: CK Tile Distribution Encoding — C Matrix Register Layout
|
||||
* Tutorial: CK Tile Distribution Encoding -- C Matrix Register Layout
|
||||
*
|
||||
* Demonstrates how C-matrix elements are distributed across thread registers
|
||||
* after MFMA computation. Unlike A/B (which are DRAM loads), C lives entirely
|
||||
* in registers — the distribution describes which thread holds which output
|
||||
* element of C = A × B.
|
||||
* in registers -- the distribution describes which thread holds which output
|
||||
* element of C = A x B.
|
||||
*
|
||||
* This tutorial shows BOTH:
|
||||
* 1. The warp-level C distribution (from MFMA m32n32k8 output mapping)
|
||||
@@ -17,11 +17,11 @@
|
||||
* The macro CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION (default 1) selects
|
||||
* between the standard and transposed C register layouts.
|
||||
*
|
||||
* Tile: M=256 × N=128 (matches the naive GEMM's C block tile)
|
||||
* Tile: M=256 x N=128 (matches the naive GEMM's C block tile)
|
||||
* Warp config: MWarp=4, NWarp=1
|
||||
* MFMA: m32n32k8 (each warp produces a 32×32 output)
|
||||
* MFMA: m32n32k8 (each warp produces a 32x32 output)
|
||||
*
|
||||
* No actual MFMA compute — we construct a C distributed tensor, fill it
|
||||
* No actual MFMA compute -- we construct a C distributed tensor, fill it
|
||||
* with marker values (thread_id * 1000 + buffer_index), and print per-thread
|
||||
* contents to reveal which buffer slots belong to which thread.
|
||||
*
|
||||
@@ -45,24 +45,24 @@ using namespace ck_tile;
|
||||
// THE GOAL
|
||||
// ============================================================================
|
||||
// After GEMM computation, each thread holds a subset of the C matrix
|
||||
// (M=256 × N=128 = 32768 elements) in its registers. We want to understand
|
||||
// (M=256 x N=128 = 32768 elements) in its registers. We want to understand
|
||||
// exactly which C[m][n] elements each thread owns.
|
||||
//
|
||||
// The mapping has two levels:
|
||||
//
|
||||
// BLOCK LEVEL (256×128 → warps and iterations):
|
||||
// BLOCK LEVEL (256x128 -> warps and iterations):
|
||||
// - 4 warps along M (MWarp=4), 1 warp along N (NWarp=1)
|
||||
// - Each warp covers 32 M-rows × 128 N-cols of the block tile
|
||||
// - Each warp covers 32 M-rows x 128 N-cols of the block tile
|
||||
// - Within each warp: MIterPerWarp=2, NIterPerWarp=4
|
||||
// → 2 × 4 = 8 warp-tile iterations per warp
|
||||
// - Each warp-tile iteration is a 32×32 MFMA output
|
||||
// -> 2 x 4 = 8 warp-tile iterations per warp
|
||||
// - Each warp-tile iteration is a 32x32 MFMA output
|
||||
//
|
||||
// WARP LEVEL (32×32 → threads):
|
||||
// - 64 threads produce 32 × 32 = 1024 C elements
|
||||
// WARP LEVEL (32x32 -> threads):
|
||||
// - 64 threads produce 32 x 32 = 1024 C elements
|
||||
// - Each thread holds 1024/64 = 16 elements
|
||||
// - MFMA m32n32k8 arranges these 16 elements in a specific pattern
|
||||
//
|
||||
// The per-thread register buffer = 8 iterations × 16 elements = 128 floats.
|
||||
// The per-thread register buffer = 8 iterations x 16 elements = 128 floats.
|
||||
//
|
||||
// ============================================================================
|
||||
// THE SOLUTION: Two-Level Distribution
|
||||
@@ -70,105 +70,105 @@ using namespace ck_tile;
|
||||
//
|
||||
// --- WARP-LEVEL C DISTRIBUTION (from MFMA m32n32k8) ---
|
||||
//
|
||||
// For fp16→fp32 MFMA m32n32k8 output (kCM0PerLane=4, kCMLane=2,
|
||||
// For fp16->fp32 MFMA m32n32k8 output (kCM0PerLane=4, kCMLane=2,
|
||||
// kCM1PerLane=4, kCNLane=32):
|
||||
//
|
||||
// STANDARD (CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION=0):
|
||||
//
|
||||
// Hs[0] = sequence<4, 2, 4> → M-dim: 4 × 2 × 4 = 32
|
||||
// Hs[1] = sequence<32> → N-dim: 32
|
||||
// Ps_major = tuple<sequence<1, 2>> → lane maps to Hs[0][1] and Hs[1][0]
|
||||
// Hs[0] = sequence<4, 2, 4> -> M-dim: 4 x 2 x 4 = 32
|
||||
// Hs[1] = sequence<32> -> N-dim: 32
|
||||
// Ps_major = tuple<sequence<1, 2>> -> lane maps to Hs[0][1] and Hs[1][0]
|
||||
// Ps_minor = tuple<sequence<1, 0>>
|
||||
//
|
||||
// How to read Ps: the tuple has 1 element → NDimP=1 → P0 = lane_id.
|
||||
// P0: major=<1,2>, minor=<1,0> → merged:
|
||||
// Hs[0] level 1 → kCMLane=2 (outer, M-half)
|
||||
// Hs[1] level 0 → kCNLane=32 (inner, N-col → contiguous!)
|
||||
// lane / 32 → M-half, lane % 32 → N-col
|
||||
// How to read Ps: the tuple has 1 element -> NDimP=1 -> P0 = lane_id.
|
||||
// P0: major=<1,2>, minor=<1,0> -> merged:
|
||||
// Hs[0] level 1 -> kCMLane=2 (outer, M-half)
|
||||
// Hs[1] level 0 -> kCNLane=32 (inner, N-col -> contiguous!)
|
||||
// lane / 32 -> M-half, lane % 32 -> N-col
|
||||
//
|
||||
// Ys_major = sequence<1, 1>
|
||||
// Ys_minor = sequence<0, 2>
|
||||
//
|
||||
// How to read Ys: parallel arrays — position i gives Yi.
|
||||
// How to read Ys: parallel arrays -- position i gives Yi.
|
||||
//
|
||||
// Ys_major = seq< 1, 1 > → Y0 is in Hs[0], Y1 is in Hs[0]
|
||||
// Ys_minor = seq< 0, 2 > → Y0 is level 0, Y1 is level 2
|
||||
// ─Y0─ ─Y1─
|
||||
// Ys_major = seq< 1, 1 > -> Y0 is in Hs[0], Y1 is in Hs[0]
|
||||
// Ys_minor = seq< 0, 2 > -> Y0 is level 0, Y1 is level 2
|
||||
// -Y0- -Y1-
|
||||
//
|
||||
// Y0: Hs[0] level 0 → kCM0PerLane=4 (M outer per lane)
|
||||
// Y1: Hs[0] level 2 → kCM1PerLane=4 (M inner per lane)
|
||||
// Y0: Hs[0] level 0 -> kCM0PerLane=4 (M outer per lane)
|
||||
// Y1: Hs[0] level 2 -> kCM1PerLane=4 (M inner per lane)
|
||||
//
|
||||
// Hs[0] Hs[1]
|
||||
// ┌─────┼─────┐ │
|
||||
// +-----+-----+ |
|
||||
// [Y0] [P0] [Y1] [P0]
|
||||
// = 4 = 2 = 4 = 32
|
||||
// (M outer)(lane) (M inner) (lane → N)
|
||||
// (M outer)(lane) (M inner) (lane -> N)
|
||||
//
|
||||
// Per-thread: Y0 × Y1 = 4 × 4 = 16 elements per warp-tile.
|
||||
// Lane decomposition: lane / 32 → M-half (0..1), lane % 32 → N-col (0..31)
|
||||
// Per-thread: Y0 x Y1 = 4 x 4 = 16 elements per warp-tile.
|
||||
// Lane decomposition: lane / 32 -> M-half (0..1), lane % 32 -> N-col (0..31)
|
||||
//
|
||||
// TRANSPOSED (CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION=1):
|
||||
//
|
||||
// Hs[0] = sequence<32> → First dim: N (swapped!)
|
||||
// Hs[1] = sequence<4, 2, 4> → Second dim: M (swapped!)
|
||||
// Ps_major = tuple<sequence<2, 1>> → lane maps to Hs[1][1] and Hs[0][0]
|
||||
// Hs[0] = sequence<32> -> First dim: N (swapped!)
|
||||
// Hs[1] = sequence<4, 2, 4> -> Second dim: M (swapped!)
|
||||
// Ps_major = tuple<sequence<2, 1>> -> lane maps to Hs[1][1] and Hs[0][0]
|
||||
// Ps_minor = tuple<sequence<1, 0>>
|
||||
//
|
||||
// How to read Ps: tuple has 1 element → NDimP=1 → P0 = lane_id.
|
||||
// P0: major=<2,1>, minor=<1,0> → merged:
|
||||
// Hs[1] level 1 → kCMLane=2 (outer, M-half)
|
||||
// Hs[0] level 0 → kCNLane=32 (inner, N-col → contiguous!)
|
||||
// How to read Ps: tuple has 1 element -> NDimP=1 -> P0 = lane_id.
|
||||
// P0: major=<2,1>, minor=<1,0> -> merged:
|
||||
// Hs[1] level 1 -> kCMLane=2 (outer, M-half)
|
||||
// Hs[0] level 0 -> kCNLane=32 (inner, N-col -> contiguous!)
|
||||
// Same lane decomposition as standard, but dimensions are swapped.
|
||||
//
|
||||
// Ys_major = sequence<2, 2>
|
||||
// Ys_minor = sequence<0, 2>
|
||||
//
|
||||
// How to read Ys:
|
||||
// Ys_major = seq< 2, 2 > → Y0 is in Hs[1], Y1 is in Hs[1]
|
||||
// Ys_minor = seq< 0, 2 > → Y0 is level 0, Y1 is level 2
|
||||
// ─Y0─ ─Y1─
|
||||
// Ys_major = seq< 2, 2 > -> Y0 is in Hs[1], Y1 is in Hs[1]
|
||||
// Ys_minor = seq< 0, 2 > -> Y0 is level 0, Y1 is level 2
|
||||
// -Y0- -Y1-
|
||||
//
|
||||
// Y0: Hs[1] level 0 → kCM0PerLane=4 (M outer per lane)
|
||||
// Y1: Hs[1] level 2 → kCM1PerLane=4 (M inner per lane)
|
||||
// Y0: Hs[1] level 0 -> kCM0PerLane=4 (M outer per lane)
|
||||
// Y1: Hs[1] level 2 -> kCM1PerLane=4 (M inner per lane)
|
||||
// Same 16 elements, but now both Y dims are in Hs[1] (M is second).
|
||||
//
|
||||
// Hs[0] Hs[1]
|
||||
// │ ┌─────┼─────┐
|
||||
// | +-----+-----+
|
||||
// [P0] [Y0] [P0] [Y1]
|
||||
// = 32 = 4 = 2 = 4
|
||||
// (lane → N) (M outer)(lane)(M inner)
|
||||
// (lane -> N) (M outer)(lane)(M inner)
|
||||
//
|
||||
// Same 16 elements per thread, but N is the first dimension in the
|
||||
// distribution — this changes which elements are contiguous in the
|
||||
// distribution -- this changes which elements are contiguous in the
|
||||
// thread buffer, affecting downstream store coalescing.
|
||||
//
|
||||
// --- BLOCK-LEVEL OUTER DISTRIBUTION ---
|
||||
//
|
||||
// MIterPerWarp = MPerBlock / (MWarp × WarpGemm::kM) = 256 / (4 × 32) = 2
|
||||
// NIterPerWarp = NPerBlock / (NWarp × WarpGemm::kN) = 128 / (1 × 32) = 4
|
||||
// MIterPerWarp = MPerBlock / (MWarp x WarpGemm::kM) = 256 / (4 x 32) = 2
|
||||
// NIterPerWarp = NPerBlock / (NWarp x WarpGemm::kN) = 128 / (1 x 32) = 4
|
||||
//
|
||||
// Hs[0] = sequence<2, 4> → M-dim: 2 iters × 4 warps
|
||||
// Hs[1] = sequence<4, 1> → N-dim: 4 iters × 1 warp
|
||||
// Hs[0] = sequence<2, 4> -> M-dim: 2 iters x 4 warps
|
||||
// Hs[1] = sequence<4, 1> -> N-dim: 4 iters x 1 warp
|
||||
// Ps_major = tuple<sequence<1, 2>>
|
||||
// Ps_minor = tuple<sequence<1, 1>>
|
||||
//
|
||||
// How to read Ps: tuple has 1 element → NDimP=1 → P0 = warp_id.
|
||||
// P0: major=<1,2>, minor=<1,1> → merged:
|
||||
// Hs[0] level 1 → MWarp=4 (outer)
|
||||
// Hs[1] level 1 → NWarp=1 (inner, trivial)
|
||||
// Total: 4 × 1 = 4 = number of warps
|
||||
// How to read Ps: tuple has 1 element -> NDimP=1 -> P0 = warp_id.
|
||||
// P0: major=<1,2>, minor=<1,1> -> merged:
|
||||
// Hs[0] level 1 -> MWarp=4 (outer)
|
||||
// Hs[1] level 1 -> NWarp=1 (inner, trivial)
|
||||
// Total: 4 x 1 = 4 = number of warps
|
||||
//
|
||||
// Ys_major = sequence<1, 2>
|
||||
// Ys_minor = sequence<0, 0>
|
||||
//
|
||||
// How to read Ys:
|
||||
// Ys_major = seq< 1, 2 > → Y0 is in Hs[0], Y1 is in Hs[1]
|
||||
// Ys_minor = seq< 0, 0 > → Y0 is level 0, Y1 is level 0
|
||||
// ─Y0─ ─Y1─
|
||||
// Ys_major = seq< 1, 2 > -> Y0 is in Hs[0], Y1 is in Hs[1]
|
||||
// Ys_minor = seq< 0, 0 > -> Y0 is level 0, Y1 is level 0
|
||||
// -Y0- -Y1-
|
||||
//
|
||||
// Y0: Hs[0] level 0 → MIterPerWarp=2
|
||||
// Y1: Hs[1] level 0 → NIterPerWarp=4
|
||||
// Block-level buffer = Y0 × Y1 = 2 × 4 = 8 warp-tile slots.
|
||||
// Y0: Hs[0] level 0 -> MIterPerWarp=2
|
||||
// Y1: Hs[1] level 0 -> NIterPerWarp=4
|
||||
// Block-level buffer = Y0 x Y1 = 2 x 4 = 8 warp-tile slots.
|
||||
//
|
||||
// tile_distribution_encoding<sequence<>,
|
||||
// tuple<sequence<2, 4>, sequence<4, 1>>,
|
||||
@@ -179,7 +179,7 @@ using namespace ck_tile;
|
||||
//
|
||||
// make_embed_tile_distribution_encoding(block_outer, warp_encoding)
|
||||
// embeds the warp encoding inside each (MIter, MWarp, NIter, NWarp) cell.
|
||||
// Total per-thread buffer = 2 × 4 × 16 = 128 elements.
|
||||
// Total per-thread buffer = 2 x 4 x 16 = 128 elements.
|
||||
//
|
||||
// ============================================================================
|
||||
|
||||
@@ -354,10 +354,10 @@ int main()
|
||||
printf("=== CK Tile Distribution Tutorial 3: C-Matrix Register Layout ===\n");
|
||||
printf("=== Matches naive GEMM: MPerBlock=256, NPerBlock=128 ===\n\n");
|
||||
printf("MFMA m32n32k8: each warp produces 32x32 = 1024 elements\n");
|
||||
printf(" 64 threads per warp → 16 elements per thread per warp-tile\n");
|
||||
printf(" MWarp=4, NWarp=1 → 4 warps along M, 1 along N\n");
|
||||
printf(" MIterPerWarp=2, NIterPerWarp=4 → 8 warp-tiles per warp\n");
|
||||
printf(" Total per thread: 8 × 16 = 128 elements\n\n");
|
||||
printf(" 64 threads per warp -> 16 elements per thread per warp-tile\n");
|
||||
printf(" MWarp=4, NWarp=1 -> 4 warps along M, 1 along N\n");
|
||||
printf(" MIterPerWarp=2, NIterPerWarp=4 -> 8 warp-tiles per warp\n");
|
||||
printf(" Total per thread: 8 x 16 = 128 elements\n\n");
|
||||
|
||||
#if CK_TILE_ENABLE_TRANSPOSED_C_DISTRIBUTION
|
||||
printf("Current mode: TRANSPOSED C distribution\n");
|
||||
|
||||
Reference in New Issue
Block a user