From 63dcefffc35c812d92c131ee76c23ce204ea5600 Mon Sep 17 00:00:00 2001
From: Damien Lejeune
Date: Tue, 10 Feb 2026 13:55:07 +0000
Subject: [PATCH] WIP: v4 tile distribution working

---
 .../ck_tile/42_mhc/mhc_v4_bf16_benchmark.cpp  |   9 +-
 include/ck_tile/ops/mhc.hpp                   |   1 +
 .../ops/mhc/kernel/mhc_kernel_tile_v4.hpp     | 303 +++++++++++-------
 .../ops/mhc/pipeline/mhc_default_policy.hpp   |  28 +-
 .../ck_tile/ops/mhc/pipeline/mhc_problem.hpp  |  63 ++++
 .../ops/mhc/pipeline/mhc_problem_v4.hpp       | 134 ++++++++
 6 files changed, 413 insertions(+), 125 deletions(-)
 create mode 100644 include/ck_tile/ops/mhc/pipeline/mhc_problem_v4.hpp

diff --git a/example/ck_tile/42_mhc/mhc_v4_bf16_benchmark.cpp b/example/ck_tile/42_mhc/mhc_v4_bf16_benchmark.cpp
index 53eec3a491..22e0b1dae2 100644
--- a/example/ck_tile/42_mhc/mhc_v4_bf16_benchmark.cpp
+++ b/example/ck_tile/42_mhc/mhc_v4_bf16_benchmark.cpp
@@ -94,13 +94,8 @@ bool run_mhc_benchmark(const ck_tile::ArgParser& arg_parser)
     d_phi_mem.ToDevice(h_phi.data());
     d_output_mem.ToDevice(h_output.data());
 
-    // Define block shape - 64 threads (1 warp) to match the BlockGemmShape configuration
-    // This matches a 16x16 block tile with 1 warp (1x1 warp layout)
-    using BlockShape = ck_tile::Generic2dBlockShape<ck_tile::sequence<1, 64>,
-                                                    ck_tile::sequence<1, 64>,
-                                                    ck_tile::sequence<1, 1>>;
-
-    using Problem = ck_tile::MHCProblem<ck_tile::bf16_t, float, ck_tile::bf16_t, BlockShape>;
+    // Use MHCProblemV4, which automatically derives BlockShape from BlockGemmShape
+    using Problem = ck_tile::MHCProblemV4<ck_tile::bf16_t, float, ck_tile::bf16_t>;
 
     // V4 kernel - optimized with single-pass data loading
     using KernelV4 = ck_tile::MHCKernelV4<Problem>;

diff --git a/include/ck_tile/ops/mhc.hpp b/include/ck_tile/ops/mhc.hpp
index 690beaafae..90a4a21350 100644
--- a/include/ck_tile/ops/mhc.hpp
+++ b/include/ck_tile/ops/mhc.hpp
@@ -10,6 +10,7 @@
 #include "ck_tile/ops/mhc/pipeline/mhc_default_policy.hpp"
 #include "ck_tile/ops/mhc/pipeline/mhc_gemm_shape.hpp"
 #include "ck_tile/ops/mhc/pipeline/mhc_problem.hpp"
+#include "ck_tile/ops/mhc/pipeline/mhc_problem_v4.hpp"
 #include "ck_tile/ops/mhc/pipeline/mhc_shape.hpp"
 #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
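Note: a minimal sketch of the resulting call site, assuming MHCProblemV4 takes the
(input, compute, output) data types in that order; the exact argument list is not
visible in this hunk, so the instantiation below is illustrative rather than copied
from the tree:

    // Hypothetical instantiation for the bf16 benchmark:
    using Problem  = ck_tile::MHCProblemV4<ck_tile::bf16_t, float, ck_tile::bf16_t>;
    using KernelV4 = ck_tile::MHCKernelV4<Problem>;

    // The BlockShape formerly spelled out in the benchmark is now a member of the
    // problem type, so a call site can sanity-check it instead of defining it:
    static_assert(Problem::kBlockSize == 64, "V4 assumes one 64-lane wavefront");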
diff --git a/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v4.hpp b/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v4.hpp
index a1ea5087a2..f72b9913ab 100644
--- a/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v4.hpp
+++ b/include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v4.hpp
@@ -92,108 +92,158 @@ struct MHCKernelV4
         PhiDataType* phi_lds =
             reinterpret_cast<PhiDataType*>(smem_ptr + kMTile * kKTile * sizeof(XDataType));
 
-        // Shared memory for norm accumulation (one per batch element in tile)
-        __shared__ ComputeDataType sum_squares_shared[kMTile];
-
-        // Shared memory for GEMM result accumulation
-        __shared__ ComputeDataType result_shared[kMTile * kNTile];
-
-        // Initialize shared norm accumulators and result
-        for(index_t i = tid; i < kMTile; i += get_block_size())
+        // Thread-local norm accumulation (one per batch element in tile)
+        // Each thread accumulates for the elements it processes
+        ComputeDataType thread_sum_squares[kMTile];
+        for(index_t i = 0; i < kMTile; ++i)
         {
-            sum_squares_shared[i] = 0.0f;
+            thread_sum_squares[i] = 0.0f;
         }
-        for(index_t i = tid; i < kMTile * kNTile; i += get_block_size())
-        {
-            result_shared[i] = 0.0f;
-        }
-        block_sync_lds();
+
+        // Create BlockGemm instance and result tile (distributed tensor in registers)
+        using BlockGemm = BlockGemmASmemBSmemCRegV1<Problem, MHCDefaultPolicy>;
+        auto result_tile = BlockGemm::MakeCBlockTile();
+        set_tile(result_tile, 0.0f);
 
         // Number of K-tile iterations
         const index_t num_k_tiles = (nC + kKTile - 1) / kKTile;
 
-        // Main loop: load tiles, compute norms incrementally, and accumulate GEMM
+        // Create tensor views for X and Phi
+        auto x_tensor_full = make_naive_tensor_view<address_space_enum::global>(
+            p_x, make_tuple(batch, nC), make_tuple(nC, 1), number<1>{}, number<1>{});
+
+        auto x_tensor_padded = pad_tensor_view(x_tensor_full,
+                                               make_tuple(number<kMTile>{}, number<kKTile>{}),
+                                               sequence<true, true>{});
+
+        // Create X DRAM window with tile distribution for vectorized loading
+        constexpr auto x_load_tile_dist = Problem::MakeXLoadTileDistribution();
+        auto x_dram_window = make_tile_window(x_tensor_padded,
+                                              make_tuple(number<kMTile>{}, number<kKTile>{}),
+                                              {batch_start, 0},
+                                              x_load_tile_dist);
+
+        // Create X LDS tensor view and window
+        auto x_lds_tensor = make_naive_tensor_view<address_space_enum::lds>(
+            x_lds,
+            make_tuple(number<kMTile>{}, number<kKTile>{}),
+            make_tuple(number<kKTile>{}, number<1>{}),
+            number<1>{},
+            number<1>{});
+
+        auto x_lds_window =
+            make_tile_window(x_lds_tensor, make_tuple(number<kMTile>{}, number<kKTile>{}), {0, 0});
+
+        // Create Phi tensor view and window with tile distribution
+        auto phi_tensor_full = make_naive_tensor_view<address_space_enum::global>(
+            p_phi, make_tuple(output_dim, nC), make_tuple(1, output_dim), number<1>{}, number<1>{});
+
+        auto phi_tensor_padded = pad_tensor_view(phi_tensor_full,
+                                                 make_tuple(number<kNTile>{}, number<kKTile>{}),
+                                                 sequence<true, true>{});
+
+        constexpr auto phi_load_tile_dist = Problem::MakePhiLoadTileDistribution();
+        auto phi_dram_window = make_tile_window(phi_tensor_padded,
+                                                make_tuple(number<kNTile>{}, number<kKTile>{}),
+                                                {out_start, 0},
+                                                phi_load_tile_dist);
+
+        // Create Phi LDS tensor view and window
+        auto phi_lds_tensor = make_naive_tensor_view<address_space_enum::lds>(
+            phi_lds,
+            make_tuple(number<kNTile>{}, number<kKTile>{}),
+            make_tuple(number<kKTile>{}, number<1>{}),
+            number<1>{},
+            number<1>{});
+
+        auto phi_lds_window = make_tile_window(
+            phi_lds_tensor, make_tuple(number<kNTile>{}, number<kKTile>{}), {0, 0});
+
+        // Main loop: load tiles with vectorization, compute norms, and accumulate GEMM
         for(index_t k_tile_idx = 0; k_tile_idx < num_k_tiles; ++k_tile_idx)
         {
-            const index_t k_start = k_tile_idx * kKTile;
-            const index_t k_end   = min(k_start + kKTile, nC);
-            const index_t k_len   = k_end - k_start;
+            // Load X tile using vectorized load_tile
+            auto x_tile = make_static_distributed_tensor<XDataType>(x_load_tile_dist);
+            load_tile(x_tile, x_dram_window);
 
-            // Load X tile from global to LDS and accumulate norm
-            for(index_t i = tid; i < kMTile * kKTile; i += get_block_size())
-            {
-                const index_t local_m  = i / kKTile;
-                const index_t local_k  = i % kKTile;
-                const index_t global_m = batch_start + local_m;
-                const index_t global_k = k_start + local_k;
+            // Accumulate norms from the loaded tile into thread-local storage
+            constexpr auto x_tile_spans = decltype(x_tile)::get_distributed_spans();
+            sweep_tile_span(x_tile_spans[number<0>{}], [&](auto idx0) {
+                sweep_tile_span(x_tile_spans[number<1>{}], [&](auto idx1) {
+                    const auto tile_idx = get_x_indices_from_distributed_indices(
+                        x_tile.get_tile_distribution(), make_tuple(idx0, idx1));
 
-                XDataType x_val = 0;
-                if(global_m < batch && local_k < k_len)
-                {
-                    x_val = p_x[global_m * nC + global_k];
+                    const index_t local_m  = tile_idx.at(number<0>{});
+                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
 
-                    // Accumulate norm for this batch element using atomics
-                    ComputeDataType x_compute = type_convert<ComputeDataType>(x_val);
-                    ComputeDataType sq        = x_compute * x_compute;
-                    atomicAdd(&sum_squares_shared[local_m], sq);
-                }
-                x_lds[i] = x_val;
-            }
+                    ComputeDataType x_val = type_convert<ComputeDataType>(x_tile[i_j_idx]);
+                    thread_sum_squares[local_m] += x_val * x_val;
+                });
+            });
 
-            // Load Phi tile from global to LDS in column-major format (K x N)
-            // phi is stored in global memory as [nC, output_dim] row-major
-            // We need to transpose it to [K, N] column-major for BlockGemm
-            for(index_t i = tid; i < kKTile * kNTile; i += get_block_size())
-            {
-                const index_t local_k  = i / kNTile;
-                const index_t local_n  = i % kNTile;
-                const index_t global_k = k_start + local_k;
-                const index_t global_n = out_start + local_n;
+            // Store X tile to LDS
+            store_tile(x_lds_window, x_tile);
 
-                PhiDataType phi_val = 0;
-                if(local_k < k_len && global_n < output_dim)
-                {
-                    phi_val = p_phi[global_k * output_dim + global_n];
-                }
-                // Store in column-major: phi_lds[n * kKTile + k]
-                phi_lds[local_n * kKTile + local_k] = phi_val;
-            }
+            // Load Phi tile using vectorized load_tile
+            auto phi_tile = make_static_distributed_tensor<PhiDataType>(phi_load_tile_dist);
+            load_tile(phi_tile, phi_dram_window);
+
+            // Store Phi tile to LDS
+            store_tile(phi_lds_window, phi_tile);
 
             block_sync_lds();
 
-            // Perform manual GEMM: result_acc += x_lds * phi_lds^T
-            // Distribute work: each thread computes a subset of output elements
-            // With 64 threads and 16x16 output, each thread handles 4 elements
-            const index_t total_elements = kMTile * kNTile;
-            const index_t elements_per_thread =
-                (total_elements + get_block_size() - 1) / get_block_size();
+            // Move windows for next iteration
+            move_tile_window(x_dram_window, {0, kKTile});
+            move_tile_window(phi_dram_window, {0, kKTile});
 
-            for(index_t elem_idx = 0; elem_idx < elements_per_thread; ++elem_idx)
-            {
-                const index_t global_elem = tid * elements_per_thread + elem_idx;
-                if(global_elem < total_elements)
-                {
-                    const index_t m_idx = global_elem / kNTile;
-                    const index_t n_idx = global_elem % kNTile;
-
-                    ComputeDataType acc = 0.0f;
-                    for(index_t k_idx = 0; k_idx < kKTile; ++k_idx)
-                    {
-                        ComputeDataType x_val =
-                            type_convert<ComputeDataType>(x_lds[m_idx * kKTile + k_idx]);
-                        ComputeDataType phi_val =
-                            type_convert<ComputeDataType>(phi_lds[n_idx * kKTile + k_idx]);
-                        acc += x_val * phi_val;
-                    }
-                    // Accumulate to shared memory using atomics
-                    atomicAdd(&result_shared[m_idx * kNTile + n_idx], acc);
-                }
-            }
+            // Perform GEMM using BlockGemm with MFMA: result_tile += x_lds * phi_lds^T
+            BlockGemm{}(result_tile, x_lds_window, phi_lds_window);
 
             block_sync_lds();
         }
 
-        // Ensure all norm accumulations are complete
+        // Reduce thread-local norm accumulators using warp shuffle + shared memory
+        __shared__ ComputeDataType sum_squares_shared[kMTile];
+
+        // Initialize shared memory
+        if(tid < kMTile)
+        {
+            sum_squares_shared[tid] = 0.0f;
+        }
+        block_sync_lds();
+
+        // Warp-level reduction for each batch element
+        // Since we have 64 threads (1 warp) and kMTile=16, multiple threads contribute to each
+        // element
+        constexpr index_t threads_per_element =
+            kBlockSize / kMTile; // 64/16 = 4 threads per batch element
+
+        for(index_t local_m = 0; local_m < kMTile; ++local_m)
+        {
+            ComputeDataType my_sum = thread_sum_squares[local_m];
+
+            // Warp shuffle reduction within threads handling this batch element
+            // Threads [local_m*4, local_m*4+1, local_m*4+2, local_m*4+3] reduce together
+            const index_t my_group      = tid / threads_per_element;
+            const index_t lane_in_group = tid % threads_per_element;
+
+            if(my_group == local_m)
+            {
+// Reduce within this group of 4 threads using warp shuffle
+#pragma unroll
+                for(index_t offset = threads_per_element / 2; offset > 0; offset /= 2)
+                {
+                    my_sum += __shfl_down(my_sum, offset);
+                }
+
+                // First thread in group writes to shared memory
+                if(lane_in_group == 0)
+                {
+                    sum_squares_shared[local_m] = my_sum;
+                }
+            }
+        }
         block_sync_lds();
         // Compute inverse norms after all K-tiles processed
@@ -213,42 +263,71 @@ struct MHCKernelV4
             }
         }
 
-        // Apply normalization, activation, and write output
-        for(index_t i = tid; i < kMTile * kNTile; i += get_block_size())
-        {
-            const index_t local_m = i / kNTile;
-            const index_t local_n = i % kNTile;
+        // Apply normalization and activation in-place on result_tile
+        constexpr auto result_spans = decltype(result_tile)::get_distributed_spans();
+        sweep_tile_span(result_spans[number<0>{}], [&](auto idx0) {
+            sweep_tile_span(result_spans[number<1>{}], [&](auto idx1) {
+                const auto tile_idx = get_x_indices_from_distributed_indices(
+                    result_tile.get_tile_distribution(), make_tuple(idx0, idx1));
 
-            const index_t global_m = batch_start + local_m;
-            const index_t global_n = out_start + local_n;
+                const index_t local_m  = tile_idx.at(number<0>{});
+                const index_t local_n  = tile_idx.at(number<1>{});
+                const index_t global_m = batch_start + local_m;
+                const index_t global_n = out_start + local_n;
 
-            if(global_m >= batch || global_n >= output_dim)
-                continue;
+                if(global_m < batch && global_n < output_dim)
+                {
+                    constexpr auto i_j_idx         = make_tuple(idx0, idx1);
+                    const ComputeDataType inv_norm = inv_norms[local_m];
+                    ComputeDataType value          = result_tile[i_j_idx];
 
-            const ComputeDataType inv_norm = inv_norms[local_m];
-            ComputeDataType value          = result_shared[i];
+                    // Apply normalization and activation based on output section
+                    if(global_n < n)
+                    {
+                        ComputeDataType activated_value;
+                        Activation{}(activated_value, value);
+                        result_tile(i_j_idx) = alpha_pre * inv_norm * activated_value + bias;
+                    }
+                    else if(global_n < 2 * n)
+                    {
+                        ComputeDataType activated_value;
+                        Activation{}(activated_value, value);
+                        result_tile(i_j_idx) =
+                            alpha_post * inv_norm * 2.0f * activated_value + bias;
+                    }
+                    else
+                    {
+                        result_tile(i_j_idx) = alpha_res * inv_norm * value + bias;
+                    }
+                }
+            });
+        });
 
-            // Apply normalization and activation based on output section
-            if(global_n < n)
-            {
-                ComputeDataType activated_value;
-                Activation{}(activated_value, value);
-                value = alpha_pre * inv_norm * activated_value + bias;
-            }
-            else if(global_n < 2 * n)
-            {
-                ComputeDataType activated_value;
-                Activation{}(activated_value, value);
-                value = alpha_post * inv_norm * 2.0f * activated_value + bias;
-            }
-            else
-            {
-                value = alpha_res * inv_norm * value + bias;
-            }
+        // Cast result to output data type
+        auto result_output = cast_tile<YDataType>(result_tile);
 
-            // Write to global memory
-            p_output[global_m * output_dim + global_n] = type_convert<YDataType>(value);
-        }
+        // Create output tensor view with vectorization for efficient writes
+        constexpr index_t output_vector_size = 16 / sizeof(YDataType);
+
+        auto output_tensor_full =
+            make_naive_tensor_view<address_space_enum::global>(p_output,
+                                                               make_tuple(batch, output_dim),
+                                                               make_tuple(output_dim, 1),
+                                                               number<output_vector_size>{},
+                                                               number<1>{});
+
+        // Pad output tensor for boundary handling
+        auto output_tensor_padded = pad_tensor_view(output_tensor_full,
+                                                    make_tuple(number<kMTile>{}, number<kNTile>{}),
+                                                    sequence<true, true>{});
+
+        // Create tile window and store using vectorized store_tile
+        auto output_window = make_tile_window(output_tensor_padded,
+                                              make_tuple(number<kMTile>{}, number<kNTile>{}),
+                                              {batch_start, out_start},
+                                              result_output.get_tile_distribution());
+
+        store_tile(output_window, result_output);
     }
 };
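Note: the subtlest change above is the norm reduction. With one 64-lane wavefront and
kMTile = 16, lanes 4g..4g+3 jointly own batch row g, and each group reduces with two
__shfl_down steps. A standalone sketch of that step in plain HIP-style C++ (the function
name is hypothetical; the real kernel operates on thread_sum_squares in place):

    // Butterfly reduction over each group of 4 consecutive lanes.
    // Round 1 (offset 2): lane 4g accumulates lane 4g+2, lane 4g+1 accumulates lane 4g+3.
    // Round 2 (offset 1): lane 4g accumulates lane 4g+1, i.e. the full group sum.
    __device__ float reduce_lane_group_of_4(float my_sum)
    {
        for(int offset = 2; offset > 0; offset /= 2)
        {
            my_sum += __shfl_down(my_sum, offset);
        }
        return my_sum; // only the first lane of each group holds the group sum
    }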
diff --git a/include/ck_tile/ops/mhc/pipeline/mhc_default_policy.hpp b/include/ck_tile/ops/mhc/pipeline/mhc_default_policy.hpp
index 72b440848a..ba91a2c7cb 100644
--- a/include/ck_tile/ops/mhc/pipeline/mhc_default_policy.hpp
+++ b/include/ck_tile/ops/mhc/pipeline/mhc_default_policy.hpp
@@ -14,17 +14,33 @@ namespace ck_tile {
 
 struct MHCDefaultPolicy
 {
-    // Provide warp gemm configuration for float data types
+    // Provide warp gemm configuration for various data types
     template <typename ADataType, typename BDataType, typename CDataType>
     CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
     {
-        // For float x float -> float, provide a simple configuration
-        if constexpr(std::is_same_v<ADataType, float> &&
-                     std::is_same_v<BDataType, float> &&
+        // For bf16 x bf16 -> float (our case), use MFMA-optimized configuration
+        if constexpr(std::is_same_v<ADataType, bf16_t> &&
+                     std::is_same_v<BDataType, bf16_t> &&
                      std::is_same_v<CDataType, float>)
         {
-            // Use a simple warp gemm configuration for float
-            // This is a basic configuration - can be optimized later
+            // Use MFMA warp gemm for bf16 inputs with float accumulation
+            using WG = WarpGemmDispatcher<bf16_t, bf16_t, float, 16, 16, 16, false>;
+            return make_tuple(WG{}, 1, 1); // 1 warp in M, 1 warp in N (K warps handled separately)
+        }
+        // For float x float -> float, provide a simple configuration
+        else if constexpr(std::is_same_v<ADataType, float> &&
+                          std::is_same_v<BDataType, float> &&
+                          std::is_same_v<CDataType, float>)
+        {
             using WG = WarpGemmDispatcher<float, float, float, 16, 16, 4, false>;
             return make_tuple(WG{}, 1, 1);
         }
     }
 };

diff --git a/include/ck_tile/ops/mhc/pipeline/mhc_problem.hpp b/include/ck_tile/ops/mhc/pipeline/mhc_problem.hpp
+    // 2D block shape used by the load tile distributions below
+    using LoadBlockShape =
+        Generic2dBlockShape<sequence<kM, kK>,          // BlockTile [M, K]
+                            sequence<kThreadPerBlock_M, kThreadPerBlock_K>,
+                            sequence<1, VectorSizeA>>; // Vector [1, VectorSizeA]
+
+    // Tile distribution for loading X (input matrix) from global memory
+    // X is [Batch, nC] row-major, we load kM×kK tiles
+    // Use BlockShape parameters to ensure consistency
+    CK_TILE_HOST_DEVICE static constexpr auto MakeXLoadTileDistribution()
+    {
+        using namespace ck_tile;
+        using S = BlockShape;
+
+        // For a 2D tile [M, K], we need to define distribution for both dimensions
+        // Using BlockShape's Repeat, WarpPerBlock, ThreadPerWarp, Vector parameters
+        using XTileDistEncoding = tile_distribution_encoding<
+            sequence<>,                                                                      // R: No replication
+            tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::Vector_M>, // H0 (M/Batch dimension)
+                  sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>, // H1 (K/nC dimension) - using N params for K
+            tuple<sequence<1, 1, 2, 2>>,                                                     // P→RH major
+            tuple<sequence<1, 2, 1, 2>>,                                                     // P→RH minor
+            sequence<2, 2, 1, 1>,                                                            // Y→RH major
+            sequence<0, 3, 0, 3>>;                                                           // Y→RH minor
+
+        return make_static_tile_distribution(XTileDistEncoding{});
+    }
+
+    // Tile distribution for loading Phi (weight matrix) from global memory
+    // Phi is [output_dim, nC] row-major, we load kN×kK tiles
+    CK_TILE_HOST_DEVICE static constexpr auto MakePhiLoadTileDistribution()
+    {
+        using namespace ck_tile;
+        using S = BlockShape;
+
+        // For Phi [N, K] tile
+        using PhiTileDistEncoding = tile_distribution_encoding<
+            sequence<>,                                                                      // R: No replication
+            tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::Vector_M>,
+                  sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
+            tuple<sequence<1, 1, 2, 2>>,                                                     // P→RH major
+            tuple<sequence<1, 2, 1, 2>>,                                                     // P→RH minor
+            sequence<2, 2, 1, 1>,                                                            // Y→RH major
+            sequence<0, 3, 0, 3>>;                                                           // Y→RH minor
+
+        return make_static_tile_distribution(PhiTileDistEncoding{});
+    }
 };
 
 } // namespace ck_tile

diff --git a/include/ck_tile/ops/mhc/pipeline/mhc_problem_v4.hpp b/include/ck_tile/ops/mhc/pipeline/mhc_problem_v4.hpp
new file mode 100644
index 0000000000..fc95b5d414
--- /dev/null
+++ b/include/ck_tile/ops/mhc/pipeline/mhc_problem_v4.hpp
@@ -0,0 +1,134 @@
+// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include "ck_tile/core.hpp"
+#include "ck_tile/ops/common/tensor_layout.hpp"
+#include "ck_tile/ops/mhc/pipeline/mhc_gemm_shape.hpp"
+
+namespace ck_tile {
+
+// MHC Problem V4: Simplified version that derives BlockShape from BlockGemmShape
+// No need to manually specify BlockShape - it's automatically derived
+template <typename XDataType_, typename ComputeDataType_, typename YDataType_>
+struct MHCProblemV4
+{
+    using XDataType       = remove_cvref_t<XDataType_>;
+    using ComputeDataType = remove_cvref_t<ComputeDataType_>;
+    using YDataType       = remove_cvref_t<YDataType_>;
+
+    // PhiDataType is the same as XDataType for the weight matrix
+    using PhiDataType = XDataType;
+
+    // BlockGemm compatibility - map our types to BlockGemm's expected types
+    using ADataType = XDataType;       // Input matrix A
+    using BDataType = PhiDataType;     // Weight matrix B (phi)
+    using CDataType = ComputeDataType; // Output/accumulator matrix C
+
+    // BlockGemmShape with kM, kN, kK members for BlockGemm
+    // Using 16x16x16 tiles with 1 warp per block
+    using BlockGemmShape = TileGemmShape<sequence<16, 16, 16>,  // BlockTile (M, N, K)
+                                         sequence<1, 1, 1>,     // BlockWarps (1 warp per block)
+                                         sequence<16, 16, 16>>; // WarpTile (16x16x16 MFMA)
+
+    // Vector sizes for loading
+    static constexpr index_t VectorSizeA = 4;
+    static constexpr index_t VectorSizeB = 4;
+
+    // Derive BlockShape from BlockGemmShape
+    // Match V3's approach: use a simple 1×64 configuration for 1 warp
+    // This ensures proper tile distribution for load_tile/store_tile operations
+    using BlockShape =
+        Generic2dBlockShape<sequence<1, 64>, // BlockTile [1, 64] - simple layout for 1 warp
+                            sequence<1, 64>, // ThreadPerBlock [1, 64] = 64 threads (1 warp)
+                            sequence<1, 1>>; // Vector [1, 1] - no vectorization in BlockShape
+
+    // Layout types for BlockGemm
+    using ALayout = ck_tile::tensor_layout::gemm::RowMajor;
+    using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor;
+    using CLayout = ck_tile::tensor_layout::gemm::RowMajor;
+
+    // For GEMM pipeline compatibility
+    using AsDataTypeTuple = tuple<ADataType>;
+    using BsDataTypeTuple = tuple<BDataType>;
+    using AsLayoutTuple   = tuple<ALayout>;
+    using BsLayoutTuple   = tuple<BLayout>;
+
+    using AElementWise = identity;
+    using BElementWise = identity;
+
+    static constexpr bool TransposeC = false;
+    static constexpr bool kPadM      = true;
+    static constexpr bool kPadN      = true;
+    static constexpr bool kPadK      = true;
+    static constexpr bool Preshuffle = false;
+
+    static constexpr auto Scheduler        = GemmPipelineScheduler::Intrawave;
+    static constexpr index_t NumWaveGroups = 1;
+
+    static constexpr index_t VectorLoadSize = 16;
+
+    // kBlockSize derived from BlockShape
+    static constexpr index_t kBlockSize = BlockShape::BlockSize;
+
+    // Additional traits
+    static constexpr bool DoubleSmemBuffer      = true;
+    static constexpr bool UseStructuredSparsity = false;
+    static constexpr bool FixedVectorSize       = false;
+
+    struct Traits
+    {
+        static constexpr bool UsePersistentKernel = false;
+    };
+
+    CK_TILE_HOST static const std::string GetName() { return "MHCProblemV4"; }
+
+    // Tile distribution for loading X (input matrix) from global memory
+    // X is [Batch, nC] row-major, we load kM×kK tiles (16×16)
+    // For a 16×16 tile with 64 threads (1 warp):
+    //   M: 1 repeat × 1 warp × 16 threads × 1 vector = 16
+    //   K: 1 repeat × 1 warp × 4 threads × 4 vector  = 16
+    //   Total threads: 1 warp × (16×4) = 64 threads ✓
+    CK_TILE_HOST_DEVICE static constexpr auto MakeXLoadTileDistribution()
+    {
+        using namespace ck_tile;
+
+        // H0 (M dimension): [repeat=1, warp=1, thread=16, vector=1] = 16
+        // H1 (K dimension): [repeat=1, warp=1, thread=4, vector=4] = 16
+        // P→RH: Warp layout   = 1 warp in M × 1 warp in K
+        //       Thread layout = 16 threads in M × 4 threads in K = 64 threads/warp
+        // Y→RH: Access order  = M_repeat → M_vector → K_repeat → K_vector (vectorized)
+        using XTileDistEncoding = tile_distribution_encoding<
+            sequence<>,                            // R: No replication
+            tuple<sequence<1, 1, 16, 1>,           // H0 (M): repeat=1, warp=1, thread=16, vector=1
+                  sequence<1, 1, 4, 4>>,           // H1 (K): repeat=1, warp=1, thread=4, vector=4
+            tuple<sequence<1, 2>, sequence<1, 2>>, // P→RH major
+            tuple<sequence<1, 1>, sequence<2, 2>>, // P→RH minor
+            sequence<1, 1, 2, 2>,                  // Y→RH major
+            sequence<0, 3, 0, 3>>;                 // Y→RH minor
+
+        return make_static_tile_distribution(XTileDistEncoding{});
+    }
+
+    // Tile distribution for loading Phi (weight matrix) from global memory
+    // Phi is [output_dim, nC] row-major, we load kN×kK tiles (16×16)
+    CK_TILE_HOST_DEVICE static constexpr auto MakePhiLoadTileDistribution()
+    {
+        using namespace ck_tile;
+
+        // Same distribution as X for 16×16 tiles
+        using PhiTileDistEncoding = tile_distribution_encoding<
+            sequence<>,                            // R: No replication
+            tuple<sequence<1, 1, 16, 1>,           // H0 (N): repeat=1, warp=1, thread=16, vector=1
+                  sequence<1, 1, 4, 4>>,           // H1 (K): repeat=1, warp=1, thread=4, vector=4
+            tuple<sequence<1, 2>, sequence<1, 2>>, // P→RH major
+            tuple<sequence<1, 1>, sequence<2, 2>>, // P→RH minor
+            sequence<1, 1, 2, 2>,                  // Y→RH major
+            sequence<0, 3, 0, 3>>;                 // Y→RH minor
+
+        return make_static_tile_distribution(PhiTileDistEncoding{});
+    }
+};
+
+} // namespace ck_tile
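Note: for host-side verification, the epilogue in mhc_kernel_tile_v4.hpp reduces to the
scalar reference below. It assumes output_dim spans three equal sections of width n and
uses a value-returning activation for clarity (the kernel's Activation functor writes
through an out-parameter instead); the helper name is illustrative:

    // Reference for one output element: v is the raw GEMM result for (row, col),
    // inv_norm is 1 / ||x_row||, and n is the per-section width.
    template <typename Act>
    float mhc_epilogue_ref(float v, float inv_norm, int col, int n,
                           float alpha_pre, float alpha_post, float alpha_res,
                           float bias, Act act)
    {
        if(col < n)
            return alpha_pre * inv_norm * act(v) + bias;         // pre section
        if(col < 2 * n)
            return alpha_post * inv_norm * 2.0f * act(v) + bias; // post section, scaled by 2
        return alpha_res * inv_norm * v + bias;                  // residual section
    }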