mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-06-10 16:28:38 +00:00
[CK Tile] Support multi-vector reads in static encoding patterns (#7528)

## Motivation

The thread-raked / warp-raked / block-raked static tile distribution patterns in `ck_tile` silently produce wrong results when the contiguous tile dimension is larger than `warp_size * vector_size`, because the encoding has no per-thread iteration dimension along X.

Concretely, with `M_Tile=N_Tile=128`, `VectorSize{A,B,C}=1` in `ConvConfigComputeV3`, the grouped convolution backward-weight example reports about 50 percent wrong values, with errors starting exactly at the `X0*X1 = 64` boundary. The second pass over the contiguous dim is never performed.

This PR extends the encoding so multi-vector reads in the contiguous tile dimension are supported, while keeping every existing call site bit-for-bit identical.

## Technical Details

Three files changed.

### 1. `include/ck_tile/core/algorithm/static_encoding_pattern.hpp`

Add a per-thread X iteration dimension in all three raked specializations:

- `X0 = min(warp_size, XPerTile / X1)`: threads in X dim
- `X1 = min(LargestVec, VecSize)`: vector size per access
- `X2 = XPerTile / (X0 * X1)`: number of X-iters per thread (new)

`X2` is gated with `if constexpr (X2 == 1) { old } else { new }` in both `make_2d_static_tile_distribution()` and `make_shuffled_2d_static_tile_distribution()`. The new encoding places `X2` in the middle of the Ys iteration list, which preserves reverse symmetry between the regular `<..., X2, X1>` and shuffled `<X1, X2, ...>` encodings.

Patterns updated: `thread_raked`, `warp_raked`, `block_raked`.

### 2. `include/ck_tile/core/tensor/transpose_tile.hpp`

Added a parallel `else if constexpr (... && NDimY == 3 && ...)` branch alongside the existing `NDimY == 2` branch. The original branch is byte-for-byte unchanged. Both branches dispatch to the same `transpose_tile2d_impl_in_thread`, whose body has always been NDimY-generic (iterates with `static_for<0, NDimY, 1>` and `number<NDimY>{}`).

### 3. `experimental/grouped_convolution_tile_instances/generate_instances.py`

Removed the two now-obsolete skip guards in `parse_bwd_weight_instances` and `parse_bwd_data_instances`:

```python
if m_per_block > (warp_size * a_scalar_per_vector) or n_per_block > (warp_size * b_scalar_per_vector):
    print(f"Skipping instance {instance_id} with multiple warps per continous tile dim since it's not supported yet.")
    continue
```

Other unrelated skips (V5 / V6 / ASYNC_V4 pipeline gating, irregular-load shapes, scalar-per-vector > tile size) are kept untouched.

### Compatibility

Strict. Every existing caller has `X2 == 1` and therefore hits the original encoding path verbatim. No upstream config or pipeline behavior changes.

## Test Plan

The grouped convolution example is the natural exerciser since `GroupedConvUniversalPipelineAgBgCrPolicy` selects `thread_raked` for both A and B tiles, and all three conv directions share the same `ConvConfigComputeV3`.

For each test below we ran:

```
./build/bin/tile_example_grouped_conv_bwd_weight [-prec={fp16,bf16}]
./build/bin/tile_example_grouped_conv_fwd [-prec={fp16,bf16}]
./build/bin/tile_example_grouped_conv_bwd_data [-prec={fp16,bf16}]
```

with `ConvConfigComputeV3` tile/vector parameters tweaked to cover both code paths:

| Test | M / N / K   | VecA/B/C | A path     | B path     | dtype               |
|------|-------------|----------|------------|------------|---------------------|
| T1   | 16/64/32    | 4/8/4    | old (X2=1) | old (X2=1) | fp16                |
| T2   | 128/128/64  | 2/2/2    | old (X2=1) | old (X2=1) | fp16                |
| T3   | 256/256/64  | 1/1/1    | old (X2=1) | new (X2=4) | fp16                |
| T5   | 256/256/64  | 1/1/1    | old (X2=1) | new (X2=4) | fp16 (3 dir)        |
| T4b  | 128/128/128 | 1/1/1    | new (X2=2) | new (X2=2) | fp16 + bf16 (3 dir) |

A larger T4a (256/256/128) was attempted to stress both A and B with X2>1 on bigger tiles but was blocked by the gfx942 hardware LDS cap (128 KB > 64 KB limit), independent of this PR.
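
As a rough cross-check of the `X2` values quoted in the test table, the `X0` / `X1` / `X2` split from Technical Details can be reproduced in a few lines of Python. This is a minimal sketch, not the `ck_tile` implementation: `raked_x_split` and its `largest_vec` parameter are hypothetical names, `warp_size=64` is assumed (the gfx942 wavefront size), and `LargestVec` is assumed not to cap the vector size unless a value is passed in.

```python
def raked_x_split(x_per_tile, vec_size, warp_size=64, largest_vec=None):
    """Return (X0, X1, X2) for the contiguous tile dimension.

    Hypothetical helper mirroring the formulas in the PR description:
      X1 = min(LargestVec, VecSize)
      X0 = min(warp_size, XPerTile / X1)
      X2 = XPerTile / (X0 * X1)
    largest_vec=None means "assume LargestVec does not limit VecSize".
    """
    x1 = vec_size if largest_vec is None else min(largest_vec, vec_size)
    if x_per_tile % x1 != 0:
        raise ValueError("XPerTile must be divisible by the vector size")
    x0 = min(warp_size, x_per_tile // x1)
    if x_per_tile % (x0 * x1) != 0:
        raise ValueError("XPerTile must be divisible by X0 * X1")
    x2 = x_per_tile // (x0 * x1)
    return x0, x1, x2


def uses_new_encoding(x_per_tile, vec_size, warp_size=64):
    """True when the split needs the extra per-thread X iteration (X2 > 1)."""
    _, _, x2 = raked_x_split(x_per_tile, vec_size, warp_size)
    return x2 > 1


if __name__ == "__main__":
    # Contiguous extents of 64, 128 and 256 elements with vector size 1.
    for x_per_tile in (64, 128, 256):
        x0, x1, x2 = raked_x_split(x_per_tile, vec_size=1)
        path = "new" if x2 > 1 else "old"
        print(f"XPerTile={x_per_tile}: X0={x0} X1={x1} X2={x2} ({path} path)")
```

Under these assumptions, `X2 > 1` holds exactly when `XPerTile > warp_size * X1`, which has the same shape as the condition in the removed generator skip guard.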

For the generator change we ran:

```
python3 generate_instances.py --mode profiler --direction all
```

and verified `Skipping instance ... with multiple warps per continous tile dim` no longer appears (count went from non-zero to 0); other skip categories are unchanged.

`clang-format-18` was applied to both modified `.hpp` files (matches the repo's `.clang-format`).

## Test Result

- T1 and T2 (compat-strict, every X2 is 1, old code path): `correct`. Confirms existing callers are unaffected.
- T3 (X2=4 on B only): `correct`. First true exercise of the new NDimY=3 encoding + transpose branch.
- T5 (T3 across `fwd` + `bwd_data` + `bwd_weight`, fp16): all 3 `correct`.
- T4b (X2>1 on both A and B, fp16 + bf16, all 3 directions): all 6 runs `correct`.
- Generator: 0 `multiple warps per continous tile dim` skips remaining; other skips unchanged.

Sample run output (T4b, bf16, bwd_data):

```
shape: tile_gemm_shape_128x128x128x4_1x4x1_16x16x32
pipeline: pipeline_AgBgCrCompV3_128x128x128_256_1x1x1_1x4_1x1x1_..._DoubleSmemBuffer_0
Vector size A: 1, Vector size B: 1, Vector size C: 1
0.934907 ms, 8.34683 TFlops, 34.3178 GB/s
Relative error threshold: 0.00390625 Absolute error threshold: 0.25
The CPU verification result is: correct
```

## Submission Checklist

- [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.

---------

Co-authored-by: Cursor <cursoragent@cursor.com>