WIP: v4 tile distribution working

Damien Lejeune
2026-02-10 13:55:07 +00:00
parent 7c728adb57
commit 63dcefffc3
6 changed files with 413 additions and 125 deletions

View File

@@ -10,6 +10,7 @@
#include "ck_tile/ops/mhc/pipeline/mhc_default_policy.hpp"
#include "ck_tile/ops/mhc/pipeline/mhc_gemm_shape.hpp"
#include "ck_tile/ops/mhc/pipeline/mhc_problem.hpp"
#include "ck_tile/ops/mhc/pipeline/mhc_problem_v4.hpp"
#include "ck_tile/ops/mhc/pipeline/mhc_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"

View File

@@ -92,108 +92,158 @@ struct MHCKernelV4
PhiDataType* phi_lds =
reinterpret_cast<PhiDataType*>(smem_ptr + kMTile * kKTile * sizeof(XDataType));
// Shared memory for norm accumulation (one per batch element in tile)
__shared__ ComputeDataType sum_squares_shared[kMTile];
// Shared memory for GEMM result accumulation
__shared__ ComputeDataType result_shared[kMTile * kNTile];
// Initialize shared norm accumulators and result
for(index_t i = tid; i < kMTile; i += get_block_size())
// Thread-local norm accumulation (one per batch element in tile)
// Each thread accumulates for the elements it processes
ComputeDataType thread_sum_squares[kMTile];
for(index_t i = 0; i < kMTile; ++i)
{
sum_squares_shared[i] = 0.0f;
thread_sum_squares[i] = 0.0f;
}
for(index_t i = tid; i < kMTile * kNTile; i += get_block_size())
{
result_shared[i] = 0.0f;
}
block_sync_lds();
// Create BlockGemm instance and result tile (distributed tensor in registers)
using BlockGemm = BlockGemmASmemBSmemCRegV1<Problem, Policy>;
auto result_tile = BlockGemm::MakeCBlockTile();
set_tile(result_tile, 0.0f);
// Number of K-tile iterations
const index_t num_k_tiles = (nC + kKTile - 1) / kKTile;
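// e.g. nC = 40 with kKTile = 16 gives num_k_tiles = 3; the final partial
// tile is covered by the kPadK padding on the tensor views below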
// Main loop: load tiles, compute norms incrementally, and accumulate GEMM
// Create tensor views for X and Phi
auto x_tensor_full = make_naive_tensor_view<address_space_enum::global>(
p_x, make_tuple(batch, nC), make_tuple(nC, 1), number<1>{}, number<1>{});
auto x_tensor_padded = pad_tensor_view(x_tensor_full,
make_tuple(number<kMTile>{}, number<kKTile>{}),
sequence<false, Problem::kPadK>{});
// Create X DRAM window with tile distribution for vectorized loading
constexpr auto x_load_tile_dist = Problem::MakeXLoadTileDistribution();
auto x_dram_window = make_tile_window(x_tensor_padded,
make_tuple(number<kMTile>{}, number<kKTile>{}),
{batch_start, 0},
x_load_tile_dist);
// Create X LDS tensor view and window
auto x_lds_tensor = make_naive_tensor_view<address_space_enum::lds>(
x_lds,
make_tuple(number<kMTile>{}, number<kKTile>{}),
make_tuple(number<kKTile>{}, number<1>{}),
number<1>{},
number<1>{});
auto x_lds_window =
make_tile_window(x_lds_tensor, make_tuple(number<kMTile>{}, number<kKTile>{}), {0, 0});
// Create Phi tensor view and window with tile distribution
auto phi_tensor_full = make_naive_tensor_view<address_space_enum::global>(
p_phi, make_tuple(output_dim, nC), make_tuple(1, output_dim), number<1>{}, number<1>{});
auto phi_tensor_padded = pad_tensor_view(phi_tensor_full,
make_tuple(number<kNTile>{}, number<kKTile>{}),
sequence<false, Problem::kPadK>{});
constexpr auto phi_load_tile_dist = Problem::MakePhiLoadTileDistribution();
auto phi_dram_window = make_tile_window(phi_tensor_padded,
make_tuple(number<kNTile>{}, number<kKTile>{}),
{out_start, 0},
phi_load_tile_dist);
// Create Phi LDS tensor view and window
auto phi_lds_tensor = make_naive_tensor_view<address_space_enum::lds>(
phi_lds,
make_tuple(number<kNTile>{}, number<kKTile>{}),
make_tuple(number<kKTile>{}, number<1>{}),
number<1>{},
number<1>{});
auto phi_lds_window = make_tile_window(
phi_lds_tensor, make_tuple(number<kNTile>{}, number<kKTile>{}), {0, 0});
// Main loop: load tiles with vectorization, compute norms, and accumulate GEMM
for(index_t k_tile_idx = 0; k_tile_idx < num_k_tiles; ++k_tile_idx)
{
const index_t k_start = k_tile_idx * kKTile;
const index_t k_end = min(k_start + kKTile, nC);
const index_t k_len = k_end - k_start;
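// e.g. the last tile when nC = 40: k_start = 32, k_end = 40, k_len = 8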
// Load X tile using vectorized load_tile
auto x_tile = make_static_distributed_tensor<XDataType>(x_load_tile_dist);
load_tile(x_tile, x_dram_window);
// Load X tile from global to LDS and accumulate norm
for(index_t i = tid; i < kMTile * kKTile; i += get_block_size())
{
const index_t local_m = i / kKTile;
const index_t local_k = i % kKTile;
const index_t global_m = batch_start + local_m;
const index_t global_k = k_start + local_k;
// Accumulate norms from the loaded tile into thread-local storage
constexpr auto x_tile_spans = decltype(x_tile)::get_distributed_spans();
sweep_tile_span(x_tile_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(x_tile_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
x_tile.get_tile_distribution(), make_tuple(idx0, idx1));
XDataType x_val = 0;
if(global_m < batch && local_k < k_len)
{
x_val = p_x[global_m * nC + global_k];
const index_t local_m = tile_idx.at(number<0>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// Accumulate norm for this batch element using atomics
ComputeDataType x_compute = type_convert<ComputeDataType>(x_val);
ComputeDataType sq = x_compute * x_compute;
atomicAdd(&sum_squares_shared[local_m], sq);
}
x_lds[i] = x_val;
}
ComputeDataType x_val = type_convert<ComputeDataType>(x_tile[i_j_idx]);
thread_sum_squares[local_m] += x_val * x_val;
});
});
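// Note: sweep_tile_span visits only the coordinates owned by the calling
// thread, so the per-thread partial sums need no synchronization here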
// Load Phi tile from global to LDS in column-major format (K x N)
// phi is stored in global memory as [nC, output_dim] row-major
// We need to transpose it to [K, N] column-major for BlockGemm
for(index_t i = tid; i < kKTile * kNTile; i += get_block_size())
{
const index_t local_k = i / kNTile;
const index_t local_n = i % kNTile;
const index_t global_k = k_start + local_k;
const index_t global_n = out_start + local_n;
// Store X tile to LDS
store_tile(x_lds_window, x_tile);
PhiDataType phi_val = 0;
if(local_k < k_len && global_n < output_dim)
{
phi_val = p_phi[global_k * output_dim + global_n];
}
// Store in column-major: phi_lds[n * kKTile + k]
phi_lds[local_n * kKTile + local_k] = phi_val;
}
// Load Phi tile using vectorized load_tile
auto phi_tile = make_static_distributed_tensor<PhiDataType>(phi_load_tile_dist);
load_tile(phi_tile, phi_dram_window);
// Store Phi tile to LDS
store_tile(phi_lds_window, phi_tile);
block_sync_lds();
// Perform manual GEMM: result_shared += x_lds * phi_lds^T
// Distribute work: each thread computes a subset of output elements
// With 64 threads and 16x16 output, each thread handles 4 elements
const index_t total_elements = kMTile * kNTile;
const index_t elements_per_thread =
(total_elements + get_block_size() - 1) / get_block_size();
// Move windows for next iteration
move_tile_window(x_dram_window, {0, kKTile});
move_tile_window(phi_dram_window, {0, kKTile});
for(index_t elem_idx = 0; elem_idx < elements_per_thread; ++elem_idx)
{
const index_t global_elem = tid * elements_per_thread + elem_idx;
if(global_elem < total_elements)
{
const index_t m_idx = global_elem / kNTile;
const index_t n_idx = global_elem % kNTile;
ComputeDataType acc = 0.0f;
for(index_t k_idx = 0; k_idx < kKTile; ++k_idx)
{
ComputeDataType x_val =
type_convert<ComputeDataType>(x_lds[m_idx * kKTile + k_idx]);
ComputeDataType phi_val =
type_convert<ComputeDataType>(phi_lds[n_idx * kKTile + k_idx]);
acc += x_val * phi_val;
}
// Accumulate to shared memory using atomics
atomicAdd(&result_shared[m_idx * kNTile + n_idx], acc);
}
}
// Perform GEMM using BlockGemm with MFMA: result_tile += x_lds * phi_lds^T
BlockGemm{}(result_tile, x_lds_window, phi_lds_window);
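// (A and B are read from LDS while the accumulator tile stays resident in
// registers across all k_tile iterations, as the ASmemBSmemCReg name implies)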
block_sync_lds();
}
// Ensure all norm accumulations are complete
// Reduce thread-local norm accumulators using warp shuffle + shared memory
__shared__ ComputeDataType sum_squares_shared[kMTile];
// Initialize shared memory
if(tid < kMTile)
{
sum_squares_shared[tid] = 0.0f;
}
block_sync_lds();
// Warp-level reduction for each batch element
// Since we have 64 threads (1 warp) and kMTile = 16, four threads contribute
// to each batch element
constexpr index_t threads_per_element =
kBlockSize / kMTile; // 64/16 = 4 threads per batch element
for(index_t local_m = 0; local_m < kMTile; ++local_m)
{
ComputeDataType my_sum = thread_sum_squares[local_m];
// Warp shuffle reduction within threads handling this batch element
// Threads [local_m*4, local_m*4+1, local_m*4+2, local_m*4+3] reduce together
const index_t my_group = tid / threads_per_element;
const index_t lane_in_group = tid % threads_per_element;
if(my_group == local_m)
{
// Reduce within this group of 4 threads using warp shuffle
#pragma unroll
for(index_t offset = threads_per_element / 2; offset > 0; offset /= 2)
{
my_sum += __shfl_down(my_sum, offset);
}
// First thread in group writes to shared memory
if(lane_in_group == 0)
{
sum_squares_shared[local_m] = my_sum;
}
}
}
block_sync_lds();
// Compute inverse norms after all K-tiles processed
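
The reduction above is the standard wave-level pattern: the four lanes that share a batch element fold their partial sums with __shfl_down until the first lane of each group holds the total. A minimal standalone sketch of that pattern, assuming a 64-lane wave and 4 lanes per group (the free function is illustrative, not part of the kernel):

    // Fold the partials of 4 consecutive lanes into the group's first lane (HIP).
    __device__ float group_sum4(float partial)
    {
        // Offsets 2 then 1 halve the group each step. Lanes near the group
        // boundary pull values from the next group, but only lanes with
        // (lane % 4) == 0 consume the final result.
        partial += __shfl_down(partial, 2);
        partial += __shfl_down(partial, 1);
        return partial;
    }
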
@@ -213,42 +263,71 @@ struct MHCKernelV4
}
}
// Apply normalization, activation, and write output
for(index_t i = tid; i < kMTile * kNTile; i += get_block_size())
{
const index_t local_m = i / kNTile;
const index_t local_n = i % kNTile;
// Apply normalization and activation in-place on result_tile
constexpr auto result_spans = decltype(result_tile)::get_distributed_spans();
sweep_tile_span(result_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(result_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
result_tile.get_tile_distribution(), make_tuple(idx0, idx1));
const index_t global_m = batch_start + local_m;
const index_t global_n = out_start + local_n;
const index_t local_m = tile_idx.at(number<0>{});
const index_t local_n = tile_idx.at(number<1>{});
const index_t global_m = batch_start + local_m;
const index_t global_n = out_start + local_n;
if(global_m >= batch || global_n >= output_dim)
return;
if(global_m < batch && global_n < output_dim)
{
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const ComputeDataType inv_norm = inv_norms[local_m];
ComputeDataType value = result_tile[i_j_idx];
const ComputeDataType inv_norm = inv_norms[local_m];
ComputeDataType value = result_shared[i];
// Apply normalization and activation based on output section
if(global_n < n)
{
ComputeDataType activated_value;
Activation{}(activated_value, value);
result_tile(i_j_idx) = alpha_pre * inv_norm * activated_value + bias;
}
else if(global_n < 2 * n)
{
ComputeDataType activated_value;
Activation{}(activated_value, value);
result_tile(i_j_idx) =
alpha_post * inv_norm * 2.0f * activated_value + bias;
}
else
{
result_tile(i_j_idx) = alpha_res * inv_norm * value + bias;
}
}
});
});
// Apply normalization and activation based on output section
if(global_n < n)
{
ComputeDataType activated_value;
Activation{}(activated_value, value);
value = alpha_pre * inv_norm * activated_value + bias;
}
else if(global_n < 2 * n)
{
ComputeDataType activated_value;
Activation{}(activated_value, value);
value = alpha_post * inv_norm * 2.0f * activated_value + bias;
}
else
{
value = alpha_res * inv_norm * value + bias;
}
// Cast result to output data type
auto result_output = cast_tile<YDataType>(result_tile);
// Write to global memory
p_output[global_m * output_dim + global_n] = type_convert<YDataType>(value);
}
// Create output tensor view with vectorization for efficient writes
constexpr index_t output_vector_size = 16 / sizeof(YDataType);
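// e.g. a 2-byte YDataType such as bf16_t gives 16 / 2 = 8 elements per
// 16-byte (dwordx4) store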
auto output_tensor_full =
make_naive_tensor_view<address_space_enum::global>(p_output,
make_tuple(batch, output_dim),
make_tuple(output_dim, 1),
number<output_vector_size>{},
number<1>{});
// Pad output tensor for boundary handling
auto output_tensor_padded = pad_tensor_view(output_tensor_full,
make_tuple(number<kMTile>{}, number<kNTile>{}),
sequence<false, Problem::kPadN>{});
// Create tile window and store using vectorized store_tile
auto output_window = make_tile_window(output_tensor_padded,
make_tuple(number<kMTile>{}, number<kNTile>{}),
{batch_start, out_start},
result_output.get_tile_distribution());
store_tile(output_window, result_output);
}
};
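
The epilogue splits the output columns into three sections: [0, n) takes the activated value scaled by alpha_pre, [n, 2n) takes twice the activated value scaled by alpha_post, and the rest passes the raw value through with alpha_res. A scalar sketch of that branch in isolation (the free function is illustrative; it assumes the same out-parameter Activation convention the kernel uses):

    // Section-wise normalization: 'n' splits output_dim into pre/post/residual parts.
    template <typename Activation>
    __device__ float mhc_epilogue(float value, float inv_norm, int global_n, int n,
                                  float alpha_pre, float alpha_post,
                                  float alpha_res, float bias)
    {
        if(global_n < 2 * n)
        {
            float activated;
            Activation{}(activated, value);
            const float scale = (global_n < n) ? alpha_pre : 2.0f * alpha_post;
            return scale * inv_norm * activated + bias;
        }
        return alpha_res * inv_norm * value + bias; // residual section: no activation
    }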

View File

@@ -14,17 +14,33 @@ namespace ck_tile {
struct MHCDefaultPolicy
{
// Provide warp gemm configuration for float data types
// Provide warp gemm configuration for various data types
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
// For float x float -> float, provide a simple configuration
if constexpr(std::is_same_v<typename Problem::ADataType, float> &&
std::is_same_v<typename Problem::BDataType, float> &&
// For bf16 x bf16 -> float (our case), use MFMA-optimized configuration
if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
std::is_same_v<typename Problem::BDataType, bf16_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
// Use a simple warp gemm configuration for float
// This is a basic configuration - can be optimized later
// Use MFMA warp gemm for bf16 inputs with float accumulation
using WG = WarpGemmDispatcher<bf16_t,
bf16_t,
float,
16,
16,
16, // M, N, K per warp (MFMA 16x16x16)
true,
false,
false,
WGAttrNumAccessEnum::Single>;
return make_tuple(WG{}, 1, 1); // 1 warp in M, 1 warp in N; K is iterated inside the warp gemm, not split across warps
}
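// Note (assuming the standard 64-lane wavefront): one 16x16x16 MFMA step
// keeps the whole 16x16 C tile inside a single wavefront's registers, i.e.
// 16*16/64 = 4 accumulators per lane and 4 bf16 elements of each of the
// A and B fragments per lane per K step.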
// For float x float -> float, provide a simple configuration
else if constexpr(std::is_same_v<typename Problem::ADataType, float> &&
std::is_same_v<typename Problem::BDataType, float> &&
std::is_same_v<typename Problem::CDataType, float>)
{
using WG = WarpGemmDispatcher<float,
float,
float,

View File

@@ -77,6 +77,69 @@ struct MHCProblem
};
CK_TILE_HOST static const std::string GetName() { return "MHCProblem"; }
// Helper to derive Generic2dBlockShape from BlockGemmShape
// This ensures BlockShape parameters match our tile sizes
using DerivedBlockShape =
Generic2dBlockShape<sequence<BlockGemmShape::kM, BlockGemmShape::kK>, // BlockTile [M, K]
sequence<BlockGemmShape::kM / VectorSizeA, // ThreadPerBlock [M, K]
BlockGemmShape::kK / VectorSizeA>,
sequence<1, VectorSizeA>>; // Vector [1, VectorSizeA]
// Tile distribution for loading X (input matrix) from global memory
// X is [Batch, nC] row-major; we load kM×kK tiles
// Use BlockShape parameters to ensure consistency
CK_TILE_HOST_DEVICE static constexpr auto MakeXLoadTileDistribution()
{
using namespace ck_tile;
using S = BlockShape;
// For a 2D tile [M, K], we need to define distribution for both dimensions
// Using BlockShape's Repeat, WarpPerBlock, ThreadPerWarp, Vector parameters
using XTileDistEncoding =
tile_distribution_encoding<sequence<>, // R: No replication
tuple<sequence<S::Repeat_M,
S::WarpPerBlock_M,
S::ThreadPerWarp_M,
S::Vector_M>, // H0 (M/Batch dimension)
sequence<S::Repeat_N,
S::WarpPerBlock_N,
S::ThreadPerWarp_N,
S::Vector_N>>, // H1 (K/nC dimension) - using
// N params for K
tuple<sequence<1, 2>>, // P→RH major
tuple<sequence<2, 2>>, // P→RH minor
sequence<2, 2, 1, 1>, // Y→RH major
sequence<0, 3, 0, 3>>; // Y→RH minor
return make_static_tile_distribution(XTileDistEncoding{});
}
// Tile distribution for loading Phi (weight matrix) from global memory
// Phi is [output_dim, nC] row-major; we load kN×kK tiles
CK_TILE_HOST_DEVICE static constexpr auto MakePhiLoadTileDistribution()
{
using namespace ck_tile;
using S = BlockShape;
// For Phi [N, K] tile
using PhiTileDistEncoding =
tile_distribution_encoding<sequence<>, // R: No replication
tuple<sequence<S::Repeat_N, // H0 (N/output_dim dimension)
S::WarpPerBlock_N,
S::ThreadPerWarp_N,
S::Vector_N>,
sequence<S::Repeat_N, // H1 (K/nC dimension) - reusing N params for K
S::WarpPerBlock_N,
S::ThreadPerWarp_N,
S::Vector_N>>,
tuple<sequence<1, 2>>, // P→RH major
tuple<sequence<2, 2>>, // P→RH minor
sequence<2, 2, 1, 1>, // Y→RH major
sequence<0, 3, 0, 3>>; // Y→RH minor
return make_static_tile_distribution(PhiTileDistEncoding{});
}
};
} // namespace ck_tile
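
For the 16×16×16 BlockGemmShape and VectorSizeA = 4 that the V4 problem (next file) uses, DerivedBlockShape resolves to a 4×4 thread grid with 4-wide vectors along K. A quick standalone check of that arithmetic (assumes those values; not part of the header):

    constexpr int kM = 16, kK = 16, kVectorSizeA = 4;
    static_assert(kM / kVectorSizeA == 4, "4 threads along M");
    static_assert(kK / kVectorSizeA == 4, "4 threads along K");
    static_assert((kK / kVectorSizeA) * kVectorSizeA == kK,
                  "one pass of 4-wide vectors spans the full K tile");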

View File

@@ -0,0 +1,134 @@
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
// SPDX-License-Identifier: MIT
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/mhc/pipeline/mhc_gemm_shape.hpp"
namespace ck_tile {
// MHC Problem V4: Simplified version that derives BlockShape from BlockGemmShape
// No need to manually specify BlockShape - it's automatically derived
template <typename XDataType_, typename ComputeDataType_, typename YDataType_>
struct MHCProblemV4
{
using XDataType = remove_cvref_t<XDataType_>;
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
using YDataType = remove_cvref_t<YDataType_>;
// PhiDataType is the same as XDataType for the weight matrix
using PhiDataType = XDataType;
// BlockGemm compatibility - map our types to BlockGemm's expected types
using ADataType = XDataType; // Input matrix A
using BDataType = PhiDataType; // Weight matrix B (phi)
using CDataType = ComputeDataType; // Output/accumulator matrix C
// BlockGemmShape with kM, kN, kK members for BlockGemm
// Using 16x16x16 tiles with 1 warp per block
using BlockGemmShape = TileGemmShape<sequence<16, 16, 16>, // BlockTile (M, N, K)
sequence<1, 1, 1>, // BlockWarps (1 warp per block)
sequence<16, 16, 16>>; // WarpTile (16x16x16 MFMA)
// Vector sizes for loading
static constexpr index_t VectorSizeA = 4;
static constexpr index_t VectorSizeB = 4;
// Derive BlockShape from BlockGemmShape
// Match V3's approach: use a simple 1×64 configuration for 1 warp
// This ensures proper tile distribution for load_tile/store_tile operations
using BlockShape =
Generic2dBlockShape<sequence<1, 64>, // BlockTile [1, 64] - simple layout for 1 warp
sequence<1, 64>, // ThreadPerBlock [1, 64] = 64 threads (1 warp)
sequence<1, 1>>; // Vector [1, 1] - no vectorization in BlockShape
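// (ThreadPerBlock [1, 64] makes BlockShape::BlockSize, and hence kBlockSize
// below, resolve to 64, i.e. one wavefront)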
// Layout types for BlockGemm
using ALayout = ck_tile::tensor_layout::gemm::RowMajor;
using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor;
using CLayout = ck_tile::tensor_layout::gemm::RowMajor;
// For GEMM pipeline compatibility
using AsDataTypeTuple = tuple<ADataType>;
using BsDataTypeTuple = tuple<BDataType>;
using AsLayoutTuple = tuple<ALayout>;
using BsLayoutTuple = tuple<BLayout>;
using AElementWise = identity;
using BElementWise = identity;
static constexpr bool TransposeC = false;
static constexpr bool kPadM = true;
static constexpr bool kPadN = true;
static constexpr bool kPadK = true;
static constexpr bool Preshuffle = false;
static constexpr auto Scheduler = GemmPipelineScheduler::Intrawave;
static constexpr index_t NumWaveGroups = 1;
static constexpr index_t VectorLoadSize = 16;
// kBlockSize derived from BlockShape
static constexpr index_t kBlockSize = BlockShape::BlockSize;
// Additional traits
static constexpr bool DoubleSmemBuffer = true;
static constexpr bool UseStructuredSparsity = false;
static constexpr bool FixedVectorSize = false;
struct Traits
{
static constexpr bool UsePersistentKernel = false;
};
CK_TILE_HOST static const std::string GetName() { return "MHCProblemV4"; }
// Tile distribution for loading X (input matrix) from global memory
// X is [Batch, nC] row-major; we load kM×kK tiles (16×16)
// For a 16×16 tile with 64 threads (1 warp):
// M: 1 repeat × 1 warp × 16 threads × 1 vector = 16
// K: 1 repeat × 1 warp × 4 threads × 4 vector = 16
// Total threads: 1 warp × (16×4) = 64 threads ✓
CK_TILE_HOST_DEVICE static constexpr auto MakeXLoadTileDistribution()
{
using namespace ck_tile;
// H0 (M dimension): [repeat=1, warp=1, thread=16, vector=1] = 16
// H1 (K dimension): [repeat=1, warp=1, thread=4, vector=4] = 16
// P→RH: Warp layout = 1 warp in M × 1 warp in K
// Thread layout = 16 threads in M × 4 threads in K = 64 threads/warp
// Y→RH: Access order = M_repeat → M_vector → K_repeat → K_vector (vectorized)
using XTileDistEncoding = tile_distribution_encoding<
sequence<>, // R: No replication
tuple<sequence<1, 1, 16, 1>, // H0 (M): repeat=1, warp=1, thread=16, vector=1
sequence<1, 1, 4, 4>>, // H1 (K): repeat=1, warp=1, thread=4, vector=4
tuple<sequence<1, 2>, sequence<1, 2>>, // P→RH major
tuple<sequence<1, 1>, sequence<2, 2>>, // P→RH minor
sequence<1, 1, 2, 2>, // Y→RH major
sequence<0, 3, 0, 3>>; // Y→RH minor
return make_static_tile_distribution(XTileDistEncoding{});
}
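// Worked example of the mapping this encoding produces (assuming the usual
// CK-tile lane ordering): lane l covers row m = l / 4 and contiguous columns
// k = 4*(l % 4) .. 4*(l % 4) + 3, i.e. one 8-byte bf16 vector per lane.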
// Tile distribution for loading Phi (weight matrix) from global memory
// Phi is [output_dim, nC] row-major; we load kN×kK tiles (16×16)
CK_TILE_HOST_DEVICE static constexpr auto MakePhiLoadTileDistribution()
{
using namespace ck_tile;
// Same distribution as X for 16×16 tiles
using PhiTileDistEncoding = tile_distribution_encoding<
sequence<>, // R: No replication
tuple<sequence<1, 1, 16, 1>, // H0 (N): repeat=1, warp=1, thread=16, vector=1
sequence<1, 1, 4, 4>>, // H1 (K): repeat=1, warp=1, thread=4, vector=4
tuple<sequence<1, 2>, sequence<1, 2>>, // P→RH major
tuple<sequence<1, 1>, sequence<2, 2>>, // P→RH minor
sequence<1, 1, 2, 2>, // Y→RH major
sequence<0, 3, 0, 3>>; // Y→RH minor
return make_static_tile_distribution(PhiTileDistEncoding{});
}
};
} // namespace ck_tile
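
As a consistency check on both encodings above: the four H factors of each dimension must multiply out to the 16-element tile extent, and the two thread factors that P selects must fill the 64-lane block. A standalone sketch of that invariant under the values used here:

    // H0 (M or N): 1*1*16*1 = 16 rows; H1 (K): 1*1*4*4 = 16 columns.
    static_assert(1 * 1 * 16 * 1 == 16, "H0 factors cover the 16-row tile");
    static_assert(1 * 1 * 4 * 4 == 16, "H1 factors cover the 16-column tile");
    static_assert(16 * 4 == 64, "thread factors fill the 64-lane block (kBlockSize)");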