mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-24 23:05:54 +00:00
Add V5: split-k
This commit is contained in:
@@ -7,10 +7,12 @@
|
||||
#include "ck_tile/ops/mhc/kernel/mhc_kernel_tile_v2.hpp"
|
||||
#include "ck_tile/ops/mhc/kernel/mhc_kernel_tile_v3.hpp"
|
||||
#include "ck_tile/ops/mhc/kernel/mhc_kernel_tile_v4.hpp"
|
||||
#include "ck_tile/ops/mhc/kernel/mhc_kernel_tile_v5.hpp"
|
||||
#include "ck_tile/ops/mhc/pipeline/mhc_default_policy.hpp"
|
||||
#include "ck_tile/ops/mhc/pipeline/mhc_gemm_shape.hpp"
|
||||
#include "ck_tile/ops/mhc/pipeline/mhc_problem.hpp"
|
||||
#include "ck_tile/ops/mhc/pipeline/mhc_problem_v4.hpp"
|
||||
#include "ck_tile/ops/mhc/pipeline/mhc_problem_v5.hpp"
|
||||
#include "ck_tile/ops/mhc/pipeline/mhc_shape.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
|
||||
409
include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v5.hpp
Normal file
409
include/ck_tile/ops/mhc/kernel/mhc_kernel_tile_v5.hpp
Normal file
@@ -0,0 +1,409 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common.hpp"
|
||||
#include "ck_tile/ops/mhc/pipeline/mhc_problem.hpp"
|
||||
#include "ck_tile/ops/mhc/pipeline/mhc_default_policy.hpp"
|
||||
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1.hpp"
|
||||
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
|
||||
|
||||
// Manifold Constrained Hyper Connection Kernel V5:
|
||||
// =====================================================================
|
||||
// Split-K implementation with 2D grid (B, C):
|
||||
// - Grid dimension 0: Batch tiles (B / kMTile)
|
||||
// - Grid dimension 1: C tiles (nC / kKTile) - split-K dimension
|
||||
// - Each block computes partial GEMM for its C-tile
|
||||
// - Results stored to workspace buffer (no atomics!)
|
||||
// - Separate reduction kernel combines partial results
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename Problem_,
|
||||
typename Policy_ = MHCDefaultPolicy,
|
||||
typename Activation_ = element_wise::Sigmoid>
|
||||
struct MHCKernelV5
|
||||
{
|
||||
using Activation = ck_tile::remove_cvref_t<Activation_>;
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using Policy = ck_tile::remove_cvref_t<Policy_>;
|
||||
|
||||
using XDataType = ck_tile::remove_cvref_t<typename Problem::XDataType>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
||||
using PhiDataType = ck_tile::remove_cvref_t<typename Problem::PhiDataType>;
|
||||
|
||||
// Tile sizes from BlockGemmShape
|
||||
static constexpr index_t kMTile = Problem::BlockGemmShape::kM; // Batch tile (16)
|
||||
static constexpr index_t kNTile = Problem::BlockGemmShape::kN; // Output tile (32)
|
||||
static constexpr index_t kKTile = Problem::BlockGemmShape::kK; // K tile for C dimension (64)
|
||||
|
||||
static constexpr index_t kBlockSize = Problem::kBlockSize;
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return kBlockSize; }
|
||||
|
||||
// Padding to avoid LDS bank conflicts
|
||||
// AMD GPUs have 32 LDS banks, 4-byte bank width
|
||||
// For bf16 (2 bytes), we need padding to avoid stride being multiple of 32
|
||||
static constexpr index_t kKTilePadded = kKTile + 8; // Add 8 elements padding
|
||||
|
||||
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
|
||||
{
|
||||
// LDS for BlockGemm with padding: A[kMTile, kKTile+8] + B[kNTile, kKTile+8]
|
||||
constexpr index_t a_lds_size = kMTile * kKTilePadded * sizeof(XDataType);
|
||||
constexpr index_t b_lds_size = kNTile * kKTilePadded * sizeof(PhiDataType);
|
||||
return a_lds_size + b_lds_size;
|
||||
}
|
||||
|
||||
// Grid configuration: 2D grid (B, C) for split-K
|
||||
CK_TILE_HOST static constexpr auto
|
||||
GetGridSize(index_t batch, [[maybe_unused]] index_t output_dim, index_t nC)
|
||||
{
|
||||
const index_t grid_m = (batch + kMTile - 1) / kMTile;
|
||||
const index_t grid_k = (nC + kKTile - 1) / kKTile; // Split-K dimension
|
||||
return make_tuple(grid_m, grid_k);
|
||||
}
|
||||
|
||||
CK_TILE_DEVICE void operator()(const XDataType* p_x,
|
||||
const PhiDataType* p_phi,
|
||||
ComputeDataType* p_workspace, // [grid_k, batch, output_dim]
|
||||
ComputeDataType* p_partial_norms, // [grid_k, batch]
|
||||
index_t batch,
|
||||
index_t nC,
|
||||
index_t output_dim,
|
||||
[[maybe_unused]] index_t n,
|
||||
[[maybe_unused]] float r = 1.0f,
|
||||
[[maybe_unused]] float alpha_pre = 1.0f,
|
||||
[[maybe_unused]] float alpha_post = 1.0f,
|
||||
[[maybe_unused]] float alpha_res = 1.0f,
|
||||
[[maybe_unused]] float bias = 0.0f) const
|
||||
{
|
||||
// 2D block indexing
|
||||
const index_t grid_m = (batch + kMTile - 1) / kMTile;
|
||||
const index_t block_m = get_block_id() % grid_m;
|
||||
const index_t block_k = get_block_id() / grid_m;
|
||||
|
||||
const index_t batch_start = block_m * kMTile;
|
||||
const index_t k_start = block_k * kKTile;
|
||||
const index_t out_start = 0;
|
||||
|
||||
if(batch_start >= batch || k_start >= nC)
|
||||
return;
|
||||
|
||||
// Allocate shared memory with padding
|
||||
__shared__ char smem_ptr[GetSmemSize()];
|
||||
XDataType* x_lds = reinterpret_cast<XDataType*>(smem_ptr);
|
||||
PhiDataType* phi_lds =
|
||||
reinterpret_cast<PhiDataType*>(smem_ptr + kMTile * kKTilePadded * sizeof(XDataType));
|
||||
|
||||
// Create BlockGemm instance and result tile
|
||||
using BlockGemm = BlockGemmASmemBSmemCRegV1<Problem, Policy>;
|
||||
auto result_tile = BlockGemm::MakeCBlockTile();
|
||||
set_tile(result_tile, 0.0f);
|
||||
|
||||
// Determine actual K size for this block
|
||||
const index_t k_size = ck_tile::min(kKTile, nC - k_start);
|
||||
|
||||
// Create tensor views for X and Phi
|
||||
auto x_tensor_full = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_x, make_tuple(batch, nC), make_tuple(nC, 1), number<1>{}, number<1>{});
|
||||
|
||||
auto x_tensor_padded = pad_tensor_view(x_tensor_full,
|
||||
make_tuple(number<kMTile>{}, number<kKTile>{}),
|
||||
sequence<false, Problem::kPadK>{});
|
||||
|
||||
constexpr auto x_load_tile_dist = Problem::MakeXLoadTileDistribution();
|
||||
auto x_dram_window = make_tile_window(x_tensor_padded,
|
||||
make_tuple(number<kMTile>{}, number<kKTile>{}),
|
||||
{batch_start, k_start},
|
||||
x_load_tile_dist);
|
||||
|
||||
auto x_lds_tensor = make_naive_tensor_view<address_space_enum::lds>(
|
||||
x_lds,
|
||||
make_tuple(number<kMTile>{}, number<kKTile>{}),
|
||||
make_tuple(number<kKTilePadded>{}, number<1>{}),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
auto x_lds_window =
|
||||
make_tile_window(x_lds_tensor, make_tuple(number<kMTile>{}, number<kKTile>{}), {0, 0});
|
||||
|
||||
// Create Phi tensor view and window
|
||||
auto phi_tensor_full = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_phi, make_tuple(output_dim, nC), make_tuple(1, output_dim), number<1>{}, number<1>{});
|
||||
|
||||
auto phi_tensor_padded = pad_tensor_view(phi_tensor_full,
|
||||
make_tuple(number<kNTile>{}, number<kKTile>{}),
|
||||
sequence<false, Problem::kPadK>{});
|
||||
|
||||
constexpr auto phi_load_tile_dist = Problem::MakePhiLoadTileDistribution();
|
||||
auto phi_dram_window = make_tile_window(phi_tensor_padded,
|
||||
make_tuple(number<kNTile>{}, number<kKTile>{}),
|
||||
{out_start, k_start},
|
||||
phi_load_tile_dist);
|
||||
|
||||
auto phi_lds_tensor = make_naive_tensor_view<address_space_enum::lds>(
|
||||
phi_lds,
|
||||
make_tuple(number<kNTile>{}, number<kKTile>{}),
|
||||
make_tuple(number<kKTilePadded>{}, number<1>{}),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
auto phi_lds_window = make_tile_window(
|
||||
phi_lds_tensor, make_tuple(number<kNTile>{}, number<kKTile>{}), {0, 0});
|
||||
|
||||
// Compute partial norms for this K-tile
|
||||
const index_t thread_id = get_thread_id();
|
||||
constexpr index_t threads_per_row = kBlockSize / kMTile;
|
||||
const index_t row_id = thread_id / threads_per_row;
|
||||
const index_t thread_in_row = thread_id % threads_per_row;
|
||||
|
||||
__shared__ ComputeDataType norm_reduction[kMTile][threads_per_row];
|
||||
|
||||
if(row_id < kMTile)
|
||||
{
|
||||
const index_t global_m = batch_start + row_id;
|
||||
ComputeDataType partial_sum = 0.0f;
|
||||
|
||||
if(global_m < batch)
|
||||
{
|
||||
const XDataType* row_ptr = p_x + global_m * nC + k_start;
|
||||
|
||||
constexpr index_t kVecSize = 4;
|
||||
for(index_t k = thread_in_row * kVecSize; k < k_size;
|
||||
k += threads_per_row * kVecSize)
|
||||
{
|
||||
if(k + kVecSize <= k_size)
|
||||
{
|
||||
using VecType = ext_vector_t<XDataType, kVecSize>;
|
||||
VecType vec = *c_style_pointer_cast<const VecType*>(row_ptr + k);
|
||||
|
||||
#pragma unroll
|
||||
for(index_t i = 0; i < kVecSize; ++i)
|
||||
{
|
||||
ComputeDataType val = type_convert<ComputeDataType>(vec[i]);
|
||||
partial_sum += val * val;
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
for(index_t i = 0; i < kVecSize && k + i < k_size; ++i)
|
||||
{
|
||||
ComputeDataType val = type_convert<ComputeDataType>(row_ptr[k + i]);
|
||||
partial_sum += val * val;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
norm_reduction[row_id][thread_in_row] = partial_sum;
|
||||
}
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// Reduce and store partial norms to global memory
|
||||
if(thread_in_row == 0 && row_id < kMTile)
|
||||
{
|
||||
const index_t global_m = batch_start + row_id;
|
||||
|
||||
if(global_m < batch)
|
||||
{
|
||||
ComputeDataType sum_squares = 0.0f;
|
||||
#pragma unroll
|
||||
for(index_t t = 0; t < threads_per_row; ++t)
|
||||
{
|
||||
sum_squares += norm_reduction[row_id][t];
|
||||
}
|
||||
|
||||
// Store to global memory: p_partial_norms[block_k, global_m]
|
||||
p_partial_norms[block_k * batch + global_m] = sum_squares;
|
||||
}
|
||||
}
|
||||
|
||||
// Load X tile for this K-slice
|
||||
auto x_tile = make_static_distributed_tensor<XDataType>(x_load_tile_dist);
|
||||
load_tile(x_tile, x_dram_window);
|
||||
store_tile(x_lds_window, x_tile);
|
||||
|
||||
// Load Phi tile for this K-slice
|
||||
auto phi_tile = make_static_distributed_tensor<PhiDataType>(phi_load_tile_dist);
|
||||
load_tile(phi_tile, phi_dram_window);
|
||||
store_tile(phi_lds_window, phi_tile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// Perform GEMM for this K-slice: result_tile = x_lds * phi_lds^T
|
||||
// Note: This is a partial result for just this K-tile
|
||||
BlockGemm{}(result_tile, x_lds_window, phi_lds_window);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// Store partial results to workspace buffer: p_workspace[block_k, batch, output_dim]
|
||||
// Layout: [grid_k][batch][output_dim]
|
||||
constexpr auto result_spans = decltype(result_tile)::get_distributed_spans();
|
||||
sweep_tile_span(result_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(result_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
result_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const index_t local_m = tile_idx.at(number<0>{});
|
||||
const index_t local_n = tile_idx.at(number<1>{});
|
||||
const index_t global_m = batch_start + local_m;
|
||||
const index_t global_n = out_start + local_n;
|
||||
|
||||
if(global_m < batch && global_n < output_dim)
|
||||
{
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
ComputeDataType value = result_tile[i_j_idx];
|
||||
|
||||
// Store to workspace: [block_k][global_m][global_n]
|
||||
const index_t workspace_idx =
|
||||
block_k * (batch * output_dim) + global_m * output_dim + global_n;
|
||||
p_workspace[workspace_idx] = value;
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
// Optimized reduction kernel with block-level shared memory reduction
|
||||
template <typename Problem_, typename Activation_ = element_wise::Sigmoid>
|
||||
struct MHCReductionKernel
|
||||
{
|
||||
using Activation = ck_tile::remove_cvref_t<Activation_>;
|
||||
using Problem = ck_tile::remove_cvref_t<Problem_>;
|
||||
using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
|
||||
using YDataType = ck_tile::remove_cvref_t<typename Problem::YDataType>;
|
||||
|
||||
static constexpr index_t kBlockSize = 256;
|
||||
static constexpr index_t kVecSize = 4; // Vectorized loads
|
||||
|
||||
CK_TILE_HOST static constexpr auto BlockSize() { return kBlockSize; }
|
||||
|
||||
CK_TILE_DEVICE void operator()(const ComputeDataType* p_workspace,
|
||||
const ComputeDataType* p_partial_norms,
|
||||
YDataType* p_output,
|
||||
index_t batch,
|
||||
index_t nC,
|
||||
index_t output_dim,
|
||||
index_t n,
|
||||
index_t grid_k,
|
||||
float alpha_pre,
|
||||
float alpha_post,
|
||||
float alpha_res,
|
||||
float bias) const
|
||||
{
|
||||
const index_t tid = get_thread_id();
|
||||
const index_t block_id = get_block_id();
|
||||
const index_t block_size = get_block_size();
|
||||
|
||||
// Each block processes multiple output elements
|
||||
// Use block-level reduction for better memory coalescing
|
||||
const index_t elements_per_block = block_size;
|
||||
const index_t global_start = block_id * elements_per_block;
|
||||
const index_t total_elements = batch * output_dim;
|
||||
|
||||
const index_t global_idx = global_start + tid;
|
||||
|
||||
if(global_idx >= total_elements)
|
||||
return;
|
||||
|
||||
const index_t global_m = global_idx / output_dim;
|
||||
const index_t global_n = global_idx % output_dim;
|
||||
|
||||
// Reduce partial norms with vectorized loads where possible
|
||||
ComputeDataType sum_squares = 0.0f;
|
||||
const index_t norm_base = global_m;
|
||||
|
||||
// Vectorized reduction for norms
|
||||
index_t k = 0;
|
||||
for(; k + kVecSize <= grid_k; k += kVecSize)
|
||||
{
|
||||
using VecType = ext_vector_t<ComputeDataType, kVecSize>;
|
||||
VecType vec_norms;
|
||||
|
||||
#pragma unroll
|
||||
for(index_t i = 0; i < kVecSize; ++i)
|
||||
{
|
||||
vec_norms[i] = p_partial_norms[(k + i) * batch + norm_base];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for(index_t i = 0; i < kVecSize; ++i)
|
||||
{
|
||||
sum_squares += vec_norms[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Handle remaining elements
|
||||
for(; k < grid_k; ++k)
|
||||
{
|
||||
sum_squares += p_partial_norms[k * batch + norm_base];
|
||||
}
|
||||
|
||||
const ComputeDataType sqrt_nC = ck_tile::sqrt(static_cast<ComputeDataType>(nC));
|
||||
ComputeDataType norm = ck_tile::sqrt(sum_squares) / sqrt_nC;
|
||||
norm = (norm > 1e-12f) ? norm : 1.0f;
|
||||
|
||||
// Reduce partial GEMM results with improved memory access pattern
|
||||
// Reorganize to improve coalescing: threads in a warp access consecutive elements
|
||||
ComputeDataType value = 0.0f;
|
||||
const index_t workspace_stride = batch * output_dim;
|
||||
|
||||
// Vectorized reduction for workspace
|
||||
k = 0;
|
||||
for(; k + kVecSize <= grid_k; k += kVecSize)
|
||||
{
|
||||
using VecType = ext_vector_t<ComputeDataType, kVecSize>;
|
||||
VecType vec_values;
|
||||
|
||||
#pragma unroll
|
||||
for(index_t i = 0; i < kVecSize; ++i)
|
||||
{
|
||||
const index_t workspace_idx = (k + i) * workspace_stride + global_idx;
|
||||
vec_values[i] = p_workspace[workspace_idx];
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
for(index_t i = 0; i < kVecSize; ++i)
|
||||
{
|
||||
value += vec_values[i];
|
||||
}
|
||||
}
|
||||
|
||||
// Handle remaining elements
|
||||
for(; k < grid_k; ++k)
|
||||
{
|
||||
const index_t workspace_idx = k * workspace_stride + global_idx;
|
||||
value += p_workspace[workspace_idx];
|
||||
}
|
||||
|
||||
// Apply normalization and activation based on output section
|
||||
ComputeDataType final_value;
|
||||
if(global_n < n)
|
||||
{
|
||||
// Pre-activation section
|
||||
ComputeDataType activated_value;
|
||||
Activation{}(activated_value, value);
|
||||
final_value = (alpha_pre / norm) * activated_value + bias;
|
||||
}
|
||||
else if(global_n < 2 * n)
|
||||
{
|
||||
// Post-activation section
|
||||
ComputeDataType activated_value;
|
||||
Activation{}(activated_value, value);
|
||||
final_value = (alpha_post / norm) * 2.0f * activated_value + bias;
|
||||
}
|
||||
else
|
||||
{
|
||||
// Residual section
|
||||
final_value = (alpha_res / norm) * value + bias;
|
||||
}
|
||||
|
||||
p_output[global_idx] = type_convert<YDataType>(final_value);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
126
include/ck_tile/ops/mhc/pipeline/mhc_problem_v5.hpp
Normal file
126
include/ck_tile/ops/mhc/pipeline/mhc_problem_v5.hpp
Normal file
@@ -0,0 +1,126 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/mhc/pipeline/mhc_gemm_shape.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// MHC Problem V5: Optimized for large C values with split-K
|
||||
// Adaptive M tile size based on batch size for optimal performance
|
||||
template <typename XDataType_,
|
||||
typename ComputeDataType_,
|
||||
typename YDataType_,
|
||||
index_t MTile_ = 16> // Default M=16 for small/medium batches
|
||||
struct MHCProblemV5
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YDataType = remove_cvref_t<YDataType_>;
|
||||
|
||||
using PhiDataType = XDataType;
|
||||
|
||||
// BlockGemm compatibility
|
||||
using ADataType = XDataType;
|
||||
using BDataType = PhiDataType;
|
||||
using CDataType = ComputeDataType;
|
||||
|
||||
static constexpr index_t kMTile = MTile_; // Adaptive M tile size
|
||||
|
||||
// Adaptive tile configuration
|
||||
// M=16 (default): Optimal for small/medium batches (B < 4096)
|
||||
// M=64: Optimal for large batches (B >= 4096)
|
||||
// N=32, K=128: Fixed for all configurations
|
||||
using BlockGemmShape = TileGemmShape<sequence<MTile_, 32, 128>, // BlockTile: Adaptive M
|
||||
sequence<1, 1, 1>, // BlockWarps: 1 warp
|
||||
sequence<MTile_, 32, 128>>; // WarpTile: matches BlockTile
|
||||
|
||||
static constexpr index_t VectorSizeA = 4;
|
||||
static constexpr index_t VectorSizeB = 4;
|
||||
|
||||
// 1 warp × 64 threads/warp = 64 threads (same as V4)
|
||||
using BlockShape = Generic2dBlockShape<sequence<1, 64>, sequence<1, 64>, sequence<1, 1>>;
|
||||
|
||||
using ALayout = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
|
||||
using AsDataTypeTuple = tuple<ADataType>;
|
||||
using BsDataTypeTuple = tuple<BDataType>;
|
||||
using AsLayoutTuple = tuple<ALayout>;
|
||||
using BsLayoutTuple = tuple<BLayout>;
|
||||
|
||||
using AElementWise = identity;
|
||||
using BElementWise = identity;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool kPadM = true;
|
||||
static constexpr bool kPadN = true;
|
||||
static constexpr bool kPadK = true;
|
||||
static constexpr bool Preshuffle = false;
|
||||
|
||||
static constexpr auto Scheduler = GemmPipelineScheduler::Intrawave;
|
||||
static constexpr index_t NumWaveGroups = 1;
|
||||
|
||||
static constexpr index_t VectorLoadSize = 16;
|
||||
static constexpr index_t kBlockSize = BlockShape::BlockSize;
|
||||
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
static constexpr bool FixedVectorSize = false;
|
||||
|
||||
struct Traits
|
||||
{
|
||||
static constexpr bool UsePersistentKernel = false;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static const std::string GetName() { return "MHCProblemV5"; }
|
||||
|
||||
// Adaptive tile distribution for loading X (input matrix)
|
||||
// X is [Batch, nC] row-major, we load kM×kK tiles
|
||||
// For M=16: H0 (M): [grid=1, warp=1, thread=16, vector=1] = 16
|
||||
// For M=64: H0 (M): [grid=4, warp=1, thread=16, vector=1] = 64
|
||||
// H1 (K): [grid=2, warp=1, thread=4, vector=16] = 128 (same for all)
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeXLoadTileDistribution()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
constexpr index_t m_grid = MTile_ / 16; // M=16 → grid=1, M=64 → grid=4
|
||||
|
||||
using XTileDistEncoding = tile_distribution_encoding<
|
||||
sequence<>, // R: No replication
|
||||
tuple<sequence<m_grid, 1, 16, 1>, // H0 (M): adaptive grid based on MTile_
|
||||
sequence<2, 1, 4, 16>>, // H1 (K): grid=2, warp=1, thread=4, vector=16
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>, // P→RH major: warp arrangement
|
||||
tuple<sequence<1, 1>, sequence<2, 2>>, // P→RH minor: thread arrangement
|
||||
sequence<1, 1, 2, 2>, // Y→RH major: data layout
|
||||
sequence<0, 3, 0, 3>>; // Y→RH minor: vectorization
|
||||
|
||||
return make_static_tile_distribution(XTileDistEncoding{});
|
||||
}
|
||||
|
||||
// Tile distribution for loading Phi (weight matrix)
|
||||
// Phi is [output_dim, nC] row-major, we load kN×kK tiles (32×128)
|
||||
// H0 (N): [grid=1, warp=1, thread=16, vector=2] = 32
|
||||
// H1 (K): [grid=2, warp=1, thread=4, vector=16] = 128
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakePhiLoadTileDistribution()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
using PhiTileDistEncoding = tile_distribution_encoding<
|
||||
sequence<>, // R: No replication
|
||||
tuple<sequence<1, 1, 16, 2>, // H0 (N): grid=1, warp=1, thread=16, vector=2
|
||||
sequence<2, 1, 4, 16>>, // H1 (K): grid=2, warp=1, thread=4, vector=16
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>, // P→RH major: warp arrangement
|
||||
tuple<sequence<1, 1>, sequence<2, 2>>, // P→RH minor: thread arrangement
|
||||
sequence<1, 1, 2, 2>, // Y→RH major: data layout
|
||||
sequence<0, 3, 0, 3>>; // Y→RH minor: vectorization
|
||||
|
||||
return make_static_tile_distribution(PhiTileDistEncoding{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user