mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
WIP: v4 tile distribution working
This commit is contained in:
@@ -10,6 +10,7 @@
|
||||
#include "ck_tile/ops/mhc/pipeline/mhc_default_policy.hpp"
|
||||
#include "ck_tile/ops/mhc/pipeline/mhc_gemm_shape.hpp"
|
||||
#include "ck_tile/ops/mhc/pipeline/mhc_problem.hpp"
|
||||
#include "ck_tile/ops/mhc/pipeline/mhc_problem_v4.hpp"
|
||||
#include "ck_tile/ops/mhc/pipeline/mhc_shape.hpp"
|
||||
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
|
||||
#include "ck_tile/ops/common/load_interleaved_pk_type.hpp"
|
||||
|
||||
@@ -92,108 +92,158 @@ struct MHCKernelV4
|
||||
PhiDataType* phi_lds =
|
||||
reinterpret_cast<PhiDataType*>(smem_ptr + kMTile * kKTile * sizeof(XDataType));
|
||||
|
||||
// Shared memory for norm accumulation (one per batch element in tile)
|
||||
__shared__ ComputeDataType sum_squares_shared[kMTile];
|
||||
|
||||
// Shared memory for GEMM result accumulation
|
||||
__shared__ ComputeDataType result_shared[kMTile * kNTile];
|
||||
|
||||
// Initialize shared norm accumulators and result
|
||||
for(index_t i = tid; i < kMTile; i += get_block_size())
|
||||
// Thread-local norm accumulation (one per batch element in tile)
|
||||
// Each thread accumulates for the elements it processes
|
||||
ComputeDataType thread_sum_squares[kMTile];
|
||||
for(index_t i = 0; i < kMTile; ++i)
|
||||
{
|
||||
sum_squares_shared[i] = 0.0f;
|
||||
thread_sum_squares[i] = 0.0f;
|
||||
}
|
||||
for(index_t i = tid; i < kMTile * kNTile; i += get_block_size())
|
||||
{
|
||||
result_shared[i] = 0.0f;
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
// Create BlockGemm instance and result tile (distributed tensor in registers)
|
||||
using BlockGemm = BlockGemmASmemBSmemCRegV1<Problem, Policy>;
|
||||
auto result_tile = BlockGemm::MakeCBlockTile();
|
||||
set_tile(result_tile, 0.0f);
|
||||
|
||||
// Number of K-tile iterations
|
||||
const index_t num_k_tiles = (nC + kKTile - 1) / kKTile;
|
||||
|
||||
// Main loop: load tiles, compute norms incrementally, and accumulate GEMM
|
||||
// Create tensor views for X and Phi
|
||||
auto x_tensor_full = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_x, make_tuple(batch, nC), make_tuple(nC, 1), number<1>{}, number<1>{});
|
||||
|
||||
auto x_tensor_padded = pad_tensor_view(x_tensor_full,
|
||||
make_tuple(number<kMTile>{}, number<kKTile>{}),
|
||||
sequence<false, Problem::kPadK>{});
|
||||
|
||||
// Create X DRAM window with tile distribution for vectorized loading
|
||||
constexpr auto x_load_tile_dist = Problem::MakeXLoadTileDistribution();
|
||||
auto x_dram_window = make_tile_window(x_tensor_padded,
|
||||
make_tuple(number<kMTile>{}, number<kKTile>{}),
|
||||
{batch_start, 0},
|
||||
x_load_tile_dist);
|
||||
|
||||
// Create X LDS tensor view and window
|
||||
auto x_lds_tensor = make_naive_tensor_view<address_space_enum::lds>(
|
||||
x_lds,
|
||||
make_tuple(number<kMTile>{}, number<kKTile>{}),
|
||||
make_tuple(number<kKTile>{}, number<1>{}),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
auto x_lds_window =
|
||||
make_tile_window(x_lds_tensor, make_tuple(number<kMTile>{}, number<kKTile>{}), {0, 0});
|
||||
|
||||
// Create Phi tensor view and window with tile distribution
|
||||
auto phi_tensor_full = make_naive_tensor_view<address_space_enum::global>(
|
||||
p_phi, make_tuple(output_dim, nC), make_tuple(1, output_dim), number<1>{}, number<1>{});
|
||||
|
||||
auto phi_tensor_padded = pad_tensor_view(phi_tensor_full,
|
||||
make_tuple(number<kNTile>{}, number<kKTile>{}),
|
||||
sequence<false, Problem::kPadK>{});
|
||||
|
||||
constexpr auto phi_load_tile_dist = Problem::MakePhiLoadTileDistribution();
|
||||
auto phi_dram_window = make_tile_window(phi_tensor_padded,
|
||||
make_tuple(number<kNTile>{}, number<kKTile>{}),
|
||||
{out_start, 0},
|
||||
phi_load_tile_dist);
|
||||
|
||||
// Create Phi LDS tensor view and window
|
||||
auto phi_lds_tensor = make_naive_tensor_view<address_space_enum::lds>(
|
||||
phi_lds,
|
||||
make_tuple(number<kNTile>{}, number<kKTile>{}),
|
||||
make_tuple(number<kKTile>{}, number<1>{}),
|
||||
number<1>{},
|
||||
number<1>{});
|
||||
|
||||
auto phi_lds_window = make_tile_window(
|
||||
phi_lds_tensor, make_tuple(number<kNTile>{}, number<kKTile>{}), {0, 0});
|
||||
|
||||
// Main loop: load tiles with vectorization, compute norms, and accumulate GEMM
|
||||
for(index_t k_tile_idx = 0; k_tile_idx < num_k_tiles; ++k_tile_idx)
|
||||
{
|
||||
const index_t k_start = k_tile_idx * kKTile;
|
||||
const index_t k_end = min(k_start + kKTile, nC);
|
||||
const index_t k_len = k_end - k_start;
|
||||
// Load X tile using vectorized load_tile
|
||||
auto x_tile = make_static_distributed_tensor<XDataType>(x_load_tile_dist);
|
||||
load_tile(x_tile, x_dram_window);
|
||||
|
||||
// Load X tile from global to LDS and accumulate norm
|
||||
for(index_t i = tid; i < kMTile * kKTile; i += get_block_size())
|
||||
{
|
||||
const index_t local_m = i / kKTile;
|
||||
const index_t local_k = i % kKTile;
|
||||
const index_t global_m = batch_start + local_m;
|
||||
const index_t global_k = k_start + local_k;
|
||||
// Accumulate norms from the loaded tile into thread-local storage
|
||||
constexpr auto x_tile_spans = decltype(x_tile)::get_distributed_spans();
|
||||
sweep_tile_span(x_tile_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(x_tile_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
x_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
XDataType x_val = 0;
|
||||
if(global_m < batch && local_k < k_len)
|
||||
{
|
||||
x_val = p_x[global_m * nC + global_k];
|
||||
const index_t local_m = tile_idx.at(number<0>{});
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
|
||||
// Accumulate norm for this batch element using atomics
|
||||
ComputeDataType x_compute = type_convert<ComputeDataType>(x_val);
|
||||
ComputeDataType sq = x_compute * x_compute;
|
||||
atomicAdd(&sum_squares_shared[local_m], sq);
|
||||
}
|
||||
x_lds[i] = x_val;
|
||||
}
|
||||
ComputeDataType x_val = type_convert<ComputeDataType>(x_tile[i_j_idx]);
|
||||
thread_sum_squares[local_m] += x_val * x_val;
|
||||
});
|
||||
});
|
||||
|
||||
// Load Phi tile from global to LDS in column-major format (K x N)
|
||||
// phi is stored in global memory as [nC, output_dim] row-major
|
||||
// We need to transpose it to [K, N] column-major for BlockGemm
|
||||
for(index_t i = tid; i < kKTile * kNTile; i += get_block_size())
|
||||
{
|
||||
const index_t local_k = i / kNTile;
|
||||
const index_t local_n = i % kNTile;
|
||||
const index_t global_k = k_start + local_k;
|
||||
const index_t global_n = out_start + local_n;
|
||||
// Store X tile to LDS
|
||||
store_tile(x_lds_window, x_tile);
|
||||
|
||||
PhiDataType phi_val = 0;
|
||||
if(local_k < k_len && global_n < output_dim)
|
||||
{
|
||||
phi_val = p_phi[global_k * output_dim + global_n];
|
||||
}
|
||||
// Store in column-major: phi_lds[n * kKTile + k]
|
||||
phi_lds[local_n * kKTile + local_k] = phi_val;
|
||||
}
|
||||
// Load Phi tile using vectorized load_tile
|
||||
auto phi_tile = make_static_distributed_tensor<PhiDataType>(phi_load_tile_dist);
|
||||
load_tile(phi_tile, phi_dram_window);
|
||||
|
||||
// Store Phi tile to LDS
|
||||
store_tile(phi_lds_window, phi_tile);
|
||||
|
||||
block_sync_lds();
|
||||
|
||||
// Perform manual GEMM: result_acc += x_lds * phi_lds^T
|
||||
// Distribute work: each thread computes a subset of output elements
|
||||
// With 64 threads and 16x16 output, each thread handles 4 elements
|
||||
const index_t total_elements = kMTile * kNTile;
|
||||
const index_t elements_per_thread =
|
||||
(total_elements + get_block_size() - 1) / get_block_size();
|
||||
// Move windows for next iteration
|
||||
move_tile_window(x_dram_window, {0, kKTile});
|
||||
move_tile_window(phi_dram_window, {0, kKTile});
|
||||
|
||||
for(index_t elem_idx = 0; elem_idx < elements_per_thread; ++elem_idx)
|
||||
{
|
||||
const index_t global_elem = tid * elements_per_thread + elem_idx;
|
||||
if(global_elem < total_elements)
|
||||
{
|
||||
const index_t m_idx = global_elem / kNTile;
|
||||
const index_t n_idx = global_elem % kNTile;
|
||||
|
||||
ComputeDataType acc = 0.0f;
|
||||
for(index_t k_idx = 0; k_idx < kKTile; ++k_idx)
|
||||
{
|
||||
ComputeDataType x_val =
|
||||
type_convert<ComputeDataType>(x_lds[m_idx * kKTile + k_idx]);
|
||||
ComputeDataType phi_val =
|
||||
type_convert<ComputeDataType>(phi_lds[n_idx * kKTile + k_idx]);
|
||||
acc += x_val * phi_val;
|
||||
}
|
||||
// Accumulate to shared memory using atomics
|
||||
atomicAdd(&result_shared[m_idx * kNTile + n_idx], acc);
|
||||
}
|
||||
}
|
||||
// Perform GEMM using BlockGemm with MFMA: result_tile += x_lds * phi_lds^T
|
||||
BlockGemm{}(result_tile, x_lds_window, phi_lds_window);
|
||||
|
||||
block_sync_lds();
|
||||
}
|
||||
|
||||
// Ensure all norm accumulations are complete
|
||||
// Reduce thread-local norm accumulators using warp shuffle + shared memory
|
||||
__shared__ ComputeDataType sum_squares_shared[kMTile];
|
||||
|
||||
// Initialize shared memory
|
||||
if(tid < kMTile)
|
||||
{
|
||||
sum_squares_shared[tid] = 0.0f;
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
// Warp-level reduction for each batch element
|
||||
// Since we have 64 threads (1 warp) and kMTile=16, multiple threads contribute to each
|
||||
// element
|
||||
constexpr index_t threads_per_element =
|
||||
kBlockSize / kMTile; // 64/16 = 4 threads per batch element
|
||||
|
||||
for(index_t local_m = 0; local_m < kMTile; ++local_m)
|
||||
{
|
||||
ComputeDataType my_sum = thread_sum_squares[local_m];
|
||||
|
||||
// Warp shuffle reduction within threads handling this batch element
|
||||
// Threads [local_m*4, local_m*4+1, local_m*4+2, local_m*4+3] reduce together
|
||||
const index_t my_group = tid / threads_per_element;
|
||||
const index_t lane_in_group = tid % threads_per_element;
|
||||
|
||||
if(my_group == local_m)
|
||||
{
|
||||
// Reduce within this group of 4 threads using warp shuffle
|
||||
#pragma unroll
|
||||
for(index_t offset = threads_per_element / 2; offset > 0; offset /= 2)
|
||||
{
|
||||
my_sum += __shfl_down(my_sum, offset);
|
||||
}
|
||||
|
||||
// First thread in group writes to shared memory
|
||||
if(lane_in_group == 0)
|
||||
{
|
||||
sum_squares_shared[local_m] = my_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
block_sync_lds();
|
||||
|
||||
// Compute inverse norms after all K-tiles processed
|
||||
@@ -213,42 +263,71 @@ struct MHCKernelV4
|
||||
}
|
||||
}
|
||||
|
||||
// Apply normalization, activation, and write output
|
||||
for(index_t i = tid; i < kMTile * kNTile; i += get_block_size())
|
||||
{
|
||||
const index_t local_m = i / kNTile;
|
||||
const index_t local_n = i % kNTile;
|
||||
// Apply normalization and activation in-place on result_tile
|
||||
constexpr auto result_spans = decltype(result_tile)::get_distributed_spans();
|
||||
sweep_tile_span(result_spans[number<0>{}], [&](auto idx0) {
|
||||
sweep_tile_span(result_spans[number<1>{}], [&](auto idx1) {
|
||||
const auto tile_idx = get_x_indices_from_distributed_indices(
|
||||
result_tile.get_tile_distribution(), make_tuple(idx0, idx1));
|
||||
|
||||
const index_t global_m = batch_start + local_m;
|
||||
const index_t global_n = out_start + local_n;
|
||||
const index_t local_m = tile_idx.at(number<0>{});
|
||||
const index_t local_n = tile_idx.at(number<1>{});
|
||||
const index_t global_m = batch_start + local_m;
|
||||
const index_t global_n = out_start + local_n;
|
||||
|
||||
if(global_m >= batch || global_n >= output_dim)
|
||||
continue;
|
||||
if(global_m < batch && global_n < output_dim)
|
||||
{
|
||||
constexpr auto i_j_idx = make_tuple(idx0, idx1);
|
||||
const ComputeDataType inv_norm = inv_norms[local_m];
|
||||
ComputeDataType value = result_tile[i_j_idx];
|
||||
|
||||
const ComputeDataType inv_norm = inv_norms[local_m];
|
||||
ComputeDataType value = result_shared[i];
|
||||
// Apply normalization and activation based on output section
|
||||
if(global_n < n)
|
||||
{
|
||||
ComputeDataType activated_value;
|
||||
Activation{}(activated_value, value);
|
||||
result_tile(i_j_idx) = alpha_pre * inv_norm * activated_value + bias;
|
||||
}
|
||||
else if(global_n < 2 * n)
|
||||
{
|
||||
ComputeDataType activated_value;
|
||||
Activation{}(activated_value, value);
|
||||
result_tile(i_j_idx) =
|
||||
alpha_post * inv_norm * 2.0f * activated_value + bias;
|
||||
}
|
||||
else
|
||||
{
|
||||
result_tile(i_j_idx) = alpha_res * inv_norm * value + bias;
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// Apply normalization and activation based on output section
|
||||
if(global_n < n)
|
||||
{
|
||||
ComputeDataType activated_value;
|
||||
Activation{}(activated_value, value);
|
||||
value = alpha_pre * inv_norm * activated_value + bias;
|
||||
}
|
||||
else if(global_n < 2 * n)
|
||||
{
|
||||
ComputeDataType activated_value;
|
||||
Activation{}(activated_value, value);
|
||||
value = alpha_post * inv_norm * 2.0f * activated_value + bias;
|
||||
}
|
||||
else
|
||||
{
|
||||
value = alpha_res * inv_norm * value + bias;
|
||||
}
|
||||
// Cast result to output data type
|
||||
auto result_output = cast_tile<YDataType>(result_tile);
|
||||
|
||||
// Write to global memory
|
||||
p_output[global_m * output_dim + global_n] = type_convert<YDataType>(value);
|
||||
}
|
||||
// Create output tensor view with vectorization for efficient writes
|
||||
constexpr index_t output_vector_size = 16 / sizeof(YDataType);
|
||||
|
||||
auto output_tensor_full =
|
||||
make_naive_tensor_view<address_space_enum::global>(p_output,
|
||||
make_tuple(batch, output_dim),
|
||||
make_tuple(output_dim, 1),
|
||||
number<output_vector_size>{},
|
||||
number<1>{});
|
||||
|
||||
// Pad output tensor for boundary handling
|
||||
auto output_tensor_padded = pad_tensor_view(output_tensor_full,
|
||||
make_tuple(number<kMTile>{}, number<kNTile>{}),
|
||||
sequence<false, Problem::kPadN>{});
|
||||
|
||||
// Create tile window and store using vectorized store_tile
|
||||
auto output_window = make_tile_window(output_tensor_padded,
|
||||
make_tuple(number<kMTile>{}, number<kNTile>{}),
|
||||
{batch_start, out_start},
|
||||
result_output.get_tile_distribution());
|
||||
|
||||
store_tile(output_window, result_output);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -14,17 +14,33 @@ namespace ck_tile {
|
||||
struct MHCDefaultPolicy
|
||||
{
|
||||
|
||||
// Provide warp gemm configuration for float data types
|
||||
// Provide warp gemm configuration for various data types
|
||||
template <typename Problem>
|
||||
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
|
||||
{
|
||||
// For float x float -> float, provide a simple configuration
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, float> &&
|
||||
std::is_same_v<typename Problem::BDataType, float> &&
|
||||
// For bf16 x bf16 -> float (our case), use MFMA-optimized configuration
|
||||
if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::BDataType, bf16_t> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
// Use a simple warp gemm configuration for float
|
||||
// This is a basic configuration - can be optimized later
|
||||
// Use MFMA warp gemm for bf16 inputs with float accumulation
|
||||
using WG = WarpGemmDispatcher<bf16_t,
|
||||
bf16_t,
|
||||
float,
|
||||
16,
|
||||
16,
|
||||
16, // M, N, K per warp (MFMA 16x16x16)
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
WGAttrNumAccessEnum::Single>;
|
||||
return make_tuple(WG{}, 1, 1); // 1 warp in M, 1 warp in N (K warps handled separately)
|
||||
}
|
||||
// For float x float -> float, provide a simple configuration
|
||||
else if constexpr(std::is_same_v<typename Problem::ADataType, float> &&
|
||||
std::is_same_v<typename Problem::BDataType, float> &&
|
||||
std::is_same_v<typename Problem::CDataType, float>)
|
||||
{
|
||||
using WG = WarpGemmDispatcher<float,
|
||||
float,
|
||||
float,
|
||||
|
||||
@@ -77,6 +77,69 @@ struct MHCProblem
|
||||
};
|
||||
|
||||
CK_TILE_HOST static const std::string GetName() { return "MHCProblem"; }
|
||||
|
||||
// Helper to derive Generic2dBlockShape from BlockGemmShape
|
||||
// This ensures BlockShape parameters match our tile sizes
|
||||
using DerivedBlockShape =
|
||||
Generic2dBlockShape<sequence<BlockGemmShape::kM, BlockGemmShape::kK>, // BlockTile [M, K]
|
||||
sequence<BlockGemmShape::kM / VectorSizeA, // ThreadPerBlock [M, K]
|
||||
BlockGemmShape::kK / VectorSizeA>,
|
||||
sequence<1, VectorSizeA>>; // Vector [1, VectorSizeA]
|
||||
|
||||
// Tile distribution for loading X (input matrix) from global memory
|
||||
// X is [Batch, nC] row-major, we load kM×kK tiles
|
||||
// Use BlockShape parameters to ensure consistency
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeXLoadTileDistribution()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
using S = BlockShape;
|
||||
|
||||
// For a 2D tile [M, K], we need to define distribution for both dimensions
|
||||
// Using BlockShape's Repeat, WarpPerBlock, ThreadPerWarp, Vector parameters
|
||||
using XTileDistEncoding =
|
||||
tile_distribution_encoding<sequence<>, // R: No replication
|
||||
tuple<sequence<S::Repeat_M,
|
||||
S::WarpPerBlock_M,
|
||||
S::ThreadPerWarp_M,
|
||||
S::Vector_M>, // H0 (M/Batch dimension)
|
||||
sequence<S::Repeat_N,
|
||||
S::WarpPerBlock_N,
|
||||
S::ThreadPerWarp_N,
|
||||
S::Vector_N>>, // H1 (K/nC dimension) - using
|
||||
// N params for K
|
||||
tuple<sequence<1, 2>>, // P→RH major
|
||||
tuple<sequence<2, 2>>, // P→RH minor
|
||||
sequence<2, 2, 1, 1>, // Y→RH major
|
||||
sequence<0, 3, 0, 3>>; // Y→RH minor
|
||||
|
||||
return make_static_tile_distribution(XTileDistEncoding{});
|
||||
}
|
||||
|
||||
// Tile distribution for loading Phi (weight matrix) from global memory
|
||||
// Phi is [output_dim, nC] row-major, we load kN×kK tiles
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakePhiLoadTileDistribution()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
using S = BlockShape;
|
||||
|
||||
// For Phi [N, K] tile
|
||||
using PhiTileDistEncoding =
|
||||
tile_distribution_encoding<sequence<>, // R: No replication
|
||||
tuple<sequence<S::Repeat_N, // H0 (N/output_dim dimension)
|
||||
S::WarpPerBlock_N,
|
||||
S::ThreadPerWarp_N,
|
||||
S::Vector_N>,
|
||||
sequence<S::Repeat_N, // H1 (K/nC dimension)
|
||||
S::WarpPerBlock_N,
|
||||
S::ThreadPerWarp_N,
|
||||
S::Vector_N>>,
|
||||
tuple<sequence<1, 2>>, // P→RH major
|
||||
tuple<sequence<2, 2>>, // P→RH minor
|
||||
sequence<2, 2, 1, 1>, // Y→RH major
|
||||
sequence<0, 3, 0, 3>>; // Y→RH minor
|
||||
|
||||
return make_static_tile_distribution(PhiTileDistEncoding{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
|
||||
134
include/ck_tile/ops/mhc/pipeline/mhc_problem_v4.hpp
Normal file
134
include/ck_tile/ops/mhc/pipeline/mhc_problem_v4.hpp
Normal file
@@ -0,0 +1,134 @@
|
||||
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/common/tensor_layout.hpp"
|
||||
#include "ck_tile/ops/mhc/pipeline/mhc_gemm_shape.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
// MHC Problem V4: Simplified version that derives BlockShape from BlockGemmShape
|
||||
// No need to manually specify BlockShape - it's automatically derived
|
||||
template <typename XDataType_, typename ComputeDataType_, typename YDataType_>
|
||||
struct MHCProblemV4
|
||||
{
|
||||
using XDataType = remove_cvref_t<XDataType_>;
|
||||
using ComputeDataType = remove_cvref_t<ComputeDataType_>;
|
||||
using YDataType = remove_cvref_t<YDataType_>;
|
||||
|
||||
// PhiDataType is the same as XDataType for the weight matrix
|
||||
using PhiDataType = XDataType;
|
||||
|
||||
// BlockGemm compatibility - map our types to BlockGemm's expected types
|
||||
using ADataType = XDataType; // Input matrix A
|
||||
using BDataType = PhiDataType; // Weight matrix B (phi)
|
||||
using CDataType = ComputeDataType; // Output/accumulator matrix C
|
||||
|
||||
// BlockGemmShape with kM, kN, kK members for BlockGemm
|
||||
// Using 16x16x16 tiles with 1 warp per block
|
||||
using BlockGemmShape = TileGemmShape<sequence<16, 16, 16>, // BlockTile (M, N, K)
|
||||
sequence<1, 1, 1>, // BlockWarps (1 warp per block)
|
||||
sequence<16, 16, 16>>; // WarpTile (16x16x16 MFMA)
|
||||
|
||||
// Vector sizes for loading
|
||||
static constexpr index_t VectorSizeA = 4;
|
||||
static constexpr index_t VectorSizeB = 4;
|
||||
|
||||
// Derive BlockShape from BlockGemmShape
|
||||
// Match V3's approach: use a simple 1×64 configuration for 1 warp
|
||||
// This ensures proper tile distribution for load_tile/store_tile operations
|
||||
using BlockShape =
|
||||
Generic2dBlockShape<sequence<1, 64>, // BlockTile [1, 64] - simple layout for 1 warp
|
||||
sequence<1, 64>, // ThreadPerBlock [1, 64] = 64 threads (1 warp)
|
||||
sequence<1, 1>>; // Vector [1, 1] - no vectorization in BlockShape
|
||||
|
||||
// Layout types for BlockGemm
|
||||
using ALayout = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
using BLayout = ck_tile::tensor_layout::gemm::ColumnMajor;
|
||||
using CLayout = ck_tile::tensor_layout::gemm::RowMajor;
|
||||
|
||||
// For GEMM pipeline compatibility
|
||||
using AsDataTypeTuple = tuple<ADataType>;
|
||||
using BsDataTypeTuple = tuple<BDataType>;
|
||||
using AsLayoutTuple = tuple<ALayout>;
|
||||
using BsLayoutTuple = tuple<BLayout>;
|
||||
|
||||
using AElementWise = identity;
|
||||
using BElementWise = identity;
|
||||
|
||||
static constexpr bool TransposeC = false;
|
||||
static constexpr bool kPadM = true;
|
||||
static constexpr bool kPadN = true;
|
||||
static constexpr bool kPadK = true;
|
||||
static constexpr bool Preshuffle = false;
|
||||
|
||||
static constexpr auto Scheduler = GemmPipelineScheduler::Intrawave;
|
||||
static constexpr index_t NumWaveGroups = 1;
|
||||
|
||||
static constexpr index_t VectorLoadSize = 16;
|
||||
|
||||
// kBlockSize derived from BlockShape
|
||||
static constexpr index_t kBlockSize = BlockShape::BlockSize;
|
||||
|
||||
// Additional traits
|
||||
static constexpr bool DoubleSmemBuffer = true;
|
||||
static constexpr bool UseStructuredSparsity = false;
|
||||
static constexpr bool FixedVectorSize = false;
|
||||
|
||||
struct Traits
|
||||
{
|
||||
static constexpr bool UsePersistentKernel = false;
|
||||
};
|
||||
|
||||
CK_TILE_HOST static const std::string GetName() { return "MHCProblemV4"; }
|
||||
|
||||
// Tile distribution for loading X (input matrix) from global memory
|
||||
// X is [Batch, nC] row-major, we load kM×kK tiles (16×16)
|
||||
// For a 16×16 tile with 64 threads (1 warp):
|
||||
// M: 1 repeat × 1 warp × 16 threads × 1 vector = 16
|
||||
// K: 1 repeat × 1 warp × 4 threads × 4 vector = 16
|
||||
// Total threads: 1 warp × (16×4) = 64 threads ✓
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakeXLoadTileDistribution()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
// H0 (M dimension): [repeat=1, warp=1, thread=16, vector=1] = 16
|
||||
// H1 (K dimension): [repeat=1, warp=1, thread=4, vector=4] = 16
|
||||
// P→RH: Warp layout = 1 warp in M × 1 warp in K
|
||||
// Thread layout = 16 threads in M × 4 threads in K = 64 threads/warp
|
||||
// Y→RH: Access order = M_repeat → M_vector → K_repeat → K_vector (vectorized)
|
||||
using XTileDistEncoding = tile_distribution_encoding<
|
||||
sequence<>, // R: No replication
|
||||
tuple<sequence<1, 1, 16, 1>, // H0 (M): repeat=1, warp=1, thread=16, vector=1
|
||||
sequence<1, 1, 4, 4>>, // H1 (K): repeat=1, warp=1, thread=4, vector=4
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>, // P→RH major
|
||||
tuple<sequence<1, 1>, sequence<2, 2>>, // P→RH minor
|
||||
sequence<1, 1, 2, 2>, // Y→RH major
|
||||
sequence<0, 3, 0, 3>>; // Y→RH minor
|
||||
|
||||
return make_static_tile_distribution(XTileDistEncoding{});
|
||||
}
|
||||
|
||||
// Tile distribution for loading Phi (weight matrix) from global memory
|
||||
// Phi is [output_dim, nC] row-major, we load kN×kK tiles (16×16)
|
||||
CK_TILE_HOST_DEVICE static constexpr auto MakePhiLoadTileDistribution()
|
||||
{
|
||||
using namespace ck_tile;
|
||||
|
||||
// Same distribution as X for 16×16 tiles
|
||||
using PhiTileDistEncoding = tile_distribution_encoding<
|
||||
sequence<>, // R: No replication
|
||||
tuple<sequence<1, 1, 16, 1>, // H0 (N): repeat=1, warp=1, thread=16, vector=1
|
||||
sequence<1, 1, 4, 4>>, // H1 (K): repeat=1, warp=1, thread=4, vector=4
|
||||
tuple<sequence<1, 2>, sequence<1, 2>>, // P→RH major
|
||||
tuple<sequence<1, 1>, sequence<2, 2>>, // P→RH minor
|
||||
sequence<1, 1, 2, 2>, // Y→RH major
|
||||
sequence<0, 3, 0, 3>>; // Y→RH minor
|
||||
|
||||
return make_static_tile_distribution(PhiTileDistEncoding{});
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user