Enable vectorization in descriptor-based batched contraction. Add pad_tensor_view to local RunGemm

Mohsen Saffari
2025-10-21 14:29:49 +00:00
parent bbfe4501fa
commit 6144f5c490
2 changed files with 109 additions and 38 deletions

@@ -185,7 +185,10 @@ template <ck_tile::index_t NumDimG,
           ck_tile::index_t NumDimM,
           ck_tile::index_t NumDimN,
           ck_tile::index_t NumDimK,
-          ck_tile::index_t NumDTensor = 0>
+          ck_tile::index_t NumDTensor = 0,
+          ck_tile::index_t VectorSizeA = 1,
+          ck_tile::index_t VectorSizeB = 1,
+          ck_tile::index_t VectorSizeE = 1>
 struct BatchedContractionKernelArgs
 {
     const void* a_ptr; ///< Pointer to input tensor A
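
The three new VectorSize parameters are appended with a default of 1, so every existing instantiation of BatchedContractionKernelArgs keeps compiling and keeps scalar (vector-size-1) behavior. A minimal standalone sketch of that pattern, with simplified names rather than the ck_tile struct itself:

#include <cstdint>

using index_t = std::int32_t;

// Simplified stand-in for BatchedContractionKernelArgs: the trailing
// vector-size parameters default to 1, preserving old call sites.
template <index_t NumDimG, index_t NumDimM, index_t NumDimN, index_t NumDimK,
          index_t NumDTensor  = 0,
          index_t VectorSizeA = 1,
          index_t VectorSizeB = 1,
          index_t VectorSizeE = 1>
struct Args
{
    static constexpr index_t vec_a = VectorSizeA;
};

// Old-style use: vector sizes silently default to 1 (scalar access).
static_assert(Args<1, 2, 2, 2>::vec_a == 1);
// New-style use: explicit vectorization.
static_assert(Args<1, 2, 2, 2, 0, 8, 8, 4>::vec_a == 8);

int main() {}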
@@ -220,14 +223,31 @@ struct BatchedContractionKernelArgs
ck_tile::index_t
stride_E; ///< Leading dimension stride for tensor E (for backward compatibility)
// Tensor descriptors (encode full multi-dimensional stride information)
// These are created on host and passed to device (like old CK)
using AGridDesc_M_K_ = decltype(TensorDescriptorUtils<NumDimG, NumDimM, NumDimN, NumDimK>::
Make_A_GridDescriptor_M_K({}, {}));
using BGridDesc_N_K_ = decltype(TensorDescriptorUtils<NumDimG, NumDimM, NumDimN, NumDimK>::
Make_B_GridDescriptor_N_K({}, {}));
using EGridDesc_M_N_ = decltype(TensorDescriptorUtils<NumDimG, NumDimM, NumDimN, NumDimK>::
Make_E_GridDescriptor_M_N({}, {}));
// Tensor descriptors (encode full multi-dimensional stride information with vectorization)
using AGridDesc_M_K_ =
decltype(TensorDescriptorUtils<NumDimG,
NumDimM,
NumDimN,
NumDimK,
VectorSizeA,
VectorSizeB,
VectorSizeE>::Make_A_GridDescriptor_M_K({}, {}));
using BGridDesc_N_K_ =
decltype(TensorDescriptorUtils<NumDimG,
NumDimM,
NumDimN,
NumDimK,
VectorSizeA,
VectorSizeB,
VectorSizeE>::Make_B_GridDescriptor_N_K({}, {}));
using EGridDesc_M_N_ =
decltype(TensorDescriptorUtils<NumDimG,
NumDimM,
NumDimN,
NumDimK,
VectorSizeA,
VectorSizeB,
VectorSizeE>::Make_E_GridDescriptor_M_N({}, {}));
AGridDesc_M_K_ a_grid_desc_m_k; ///< Tensor descriptor for A[M, K] with actual strides
BGridDesc_N_K_ b_grid_desc_n_k; ///< Tensor descriptor for B[N, K] with actual strides
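
A note on the idiom above: the descriptor types are recovered with a decltype over a call whose arguments are value-initialized with {}. The expression is unevaluated, so no descriptor is ever constructed just to name its type. A compilable sketch, with Desc and Utils as hypothetical stand-ins for the real descriptor and factory:

#include <array>
#include <cstdint>
#include <type_traits>

using index_t = std::int32_t;

// Hypothetical stand-in for the descriptor type the factory returns.
struct Desc
{
    std::array<index_t, 2> lengths;
    std::array<index_t, 2> strides;
};

struct Utils
{
    static Desc Make_A_GridDescriptor_M_K(std::array<index_t, 2> lengths,
                                          std::array<index_t, 2> strides)
    {
        return {lengths, strides};
    }
};

// {} value-initializes each argument; decltype is unevaluated, so nothing runs.
using AGridDesc_M_K_ = decltype(Utils::Make_A_GridDescriptor_M_K({}, {}));
static_assert(std::is_same_v<AGridDesc_M_K_, Desc>);

int main() {}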
@@ -293,18 +313,24 @@ struct BatchedContractionKernel
static constexpr ck_tile::index_t kBlockSize =
UniversalGemmKernel::kBlockSize; ///< GPU block size inherited from GEMM kernel
// Tensor descriptor utilities for creating stride-aware descriptors
using DescriptorUtils = TensorDescriptorUtils<NumDimG, NumDimM, NumDimN, NumDimK>;
// Tensor descriptor utilities with vectorization support
using DescriptorUtils = TensorDescriptorUtils<NumDimG,
NumDimM,
NumDimN,
NumDimK,
GemmPipeline::GetVectorSizeA(),
GemmPipeline::GetVectorSizeB(),
EpiloguePipeline::GetVectorSizeC()>;
// Tensor descriptor types (created on host, encode all stride information)
using AGridDesc_M_K = decltype(DescriptorUtils::Make_A_GridDescriptor_M_K({}, {}));
using BGridDesc_N_K = decltype(DescriptorUtils::Make_B_GridDescriptor_N_K({}, {}));
using EGridDesc_M_N = decltype(DescriptorUtils::Make_E_GridDescriptor_M_N({}, {}));
using KernelArgs =
BatchedContractionKernelArgs<NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor>; ///< Kernel
///< argument
///< structure
// Kernel arguments with vectorization support
using KernelArgs = BatchedContractionKernelArgs<NumDimG,
NumDimM,
NumDimN,
NumDimK,
NumDTensor,
GemmPipeline::GetVectorSizeA(),
GemmPipeline::GetVectorSizeB(),
EpiloguePipeline::GetVectorSizeC()>;
/// @brief Returns the kernel name for debugging and profiling purposes.
/// @return Constant string identifier for this kernel
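
Note how GemmPipeline::GetVectorSizeA() and friends appear directly in a template argument list: they are static constexpr functions, so their results are compile-time constants usable as non-type template parameters, and the vector sizes are fixed when the kernel is instantiated. A reduced sketch, with FakeGemmPipeline / FakeEpiloguePipeline as hypothetical stand-ins:

#include <cstdint>

using index_t = std::int32_t;

// Hypothetical stand-ins for GemmPipeline / EpiloguePipeline.
struct FakeGemmPipeline
{
    static constexpr index_t GetVectorSizeA() { return 8; }
    static constexpr index_t GetVectorSizeB() { return 8; }
};
struct FakeEpiloguePipeline
{
    static constexpr index_t GetVectorSizeC() { return 4; }
};

// Mirrors the shape of the new TensorDescriptorUtils parameter list.
template <index_t VectorSizeA, index_t VectorSizeB, index_t VectorSizeE>
struct DescUtilsSketch
{
    static constexpr index_t vec_a = VectorSizeA;
    static constexpr index_t vec_e = VectorSizeE;
};

// constexpr member functions evaluate at compile time inside the argument list.
using Utils = DescUtilsSketch<FakeGemmPipeline::GetVectorSizeA(),
                              FakeGemmPipeline::GetVectorSizeB(),
                              FakeEpiloguePipeline::GetVectorSizeC()>;
static_assert(Utils::vec_a == 8 && Utils::vec_e == 4);

int main() {}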
@@ -363,36 +389,76 @@ struct BatchedContractionKernel
EDataType* e_ptr,
void* smem_ptr,
const KernelArgs& kargs,
const index_t k_size,
const index_t i_m,
const index_t i_n)
{
// Create tensor views from descriptors (handles multi-dimensional strides)
#if 1 // DESCRIPTOR PATH: Full multi-dimensional stride support
// Create tensor views from descriptors
auto a_tensor_view =
make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_desc_m_k);
auto b_tensor_view =
make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_desc_n_k);
auto e_tensor_view =
make_tensor_view<address_space_enum::global>(e_ptr, kargs.e_grid_desc_m_n);
#else // NAIVE PATH: Simple views assuming contiguous (for performance testing)
auto a_tensor_view = make_naive_tensor_view<address_space_enum::global>(
a_ptr,
make_tuple(kargs.M_total, k_size),
make_tuple(kargs.K_total, 1),
number<GemmPipeline::GetVectorSizeA()>{},
number<1>{});
// Create tile windows for this block's work
auto a_block_window = make_tile_window(
auto b_tensor_view = make_naive_tensor_view<address_space_enum::global>(
b_ptr,
make_tuple(kargs.N_total, k_size),
make_tuple(kargs.K_total, 1),
number<GemmPipeline::GetVectorSizeB()>{},
number<1>{});
auto e_tensor_view = make_naive_tensor_view<address_space_enum::global>(
e_ptr,
make_tuple(kargs.M_total, kargs.N_total),
make_tuple(kargs.N_total, 1),
number<EpiloguePipeline::GetVectorSizeC()>{},
number<1>{});
#endif
// Pad views for boundary handling and optimization (like UniversalGemmKernel)
auto a_pad_view = pad_tensor_view(
a_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
auto b_pad_view = pad_tensor_view(
b_tensor_view,
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
sequence<false, GemmPipeline::kPadK>{});
auto e_pad_view = pad_tensor_view(
e_tensor_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
sequence<false, GemmPipeline::kPadN>{});
// Create tile windows from PADDED views
auto a_block_window = make_tile_window(
a_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
{i_m, 0});
auto b_block_window = make_tile_window(
b_tensor_view,
b_pad_view,
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
{i_n, 0});
auto e_block_window = make_tile_window(
e_tensor_view,
e_pad_view,
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
{i_m, i_n});
// Calculate number of K loops
const index_t num_loop =
__builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(kargs.K_total));
__builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(k_size));
// Run GEMM Pipeline (same as UniversalGemmKernel, but with descriptor-based windows)
using AElementWise = remove_cvref_t<typename GemmPipeline::AElementWise>;
@@ -631,8 +697,10 @@ struct BatchedContractionKernel
                 i_m,
                 i_n);
 #else // NEW PATH: Descriptor-based RunGemm
-        // custom descriptor-based RunGemm with full multi-dimensional stride support
-        RunGemm(a_ptr, b_ptr, ds_batch_ptr, e_ptr, smem_ptr, kargs, i_m, i_n);
+        // Custom descriptor-based RunGemm with full multi-dimensional stride support
+        // For now, use K_total (split-K to be properly implemented later)
+        const index_t k_size = kargs.K_total;
+        RunGemm(a_ptr, b_ptr, ds_batch_ptr, e_ptr, smem_ptr, kargs, k_size, i_m, i_n);
 #endif
     }
 };
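
Why the padding step matters: pad_tensor_view rounds selected view extents up to a multiple of the block tile, so the tile windows and vectorized loads never have to special-case a ragged trailing tile. The per-dimension booleans in the sequence choose which extents are eligible; in this diff M/N rows stay unpadded (false) while the K extent of A/B and the N extent of E are padded via GemmPipeline::kPadK / kPadN. A standalone sketch of the extent arithmetic (pad_extent is an illustrative helper, not a ck_tile API):

#include <cstdint>
#include <iostream>

using index_t = std::int64_t;

// Round a dimension up to the next tile multiple, but only when the
// corresponding pad flag is set (mirrors sequence<false, kPadK>).
constexpr index_t pad_extent(index_t len, index_t tile, bool do_pad)
{
    return do_pad ? ((len + tile - 1) / tile) * tile : len;
}

int main()
{
    constexpr index_t KPerBlock = 32;
    constexpr bool kPadK        = true; // mirrors GemmPipeline::kPadK

    // A view of A[M_total, K_total] = [100, 70]:
    std::cout << pad_extent(100, 64, /*do_pad=*/false) << '\n'; // 100 (M untouched)
    std::cout << pad_extent(70, KPerBlock, kPadK) << '\n';      // 96  (K padded)
}

With KPerBlock = 32 and K_total = 70, the padded view behaves as if K were 96; the out-of-range tail is handled by the view's boundary logic rather than by the caller.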

@@ -30,7 +30,10 @@ namespace ck_tile {
 template <ck_tile::index_t NumDimG,
           ck_tile::index_t NumDimM,
           ck_tile::index_t NumDimN,
-          ck_tile::index_t NumDimK>
+          ck_tile::index_t NumDimK,
+          ck_tile::index_t VectorSizeA,
+          ck_tile::index_t VectorSizeB,
+          ck_tile::index_t VectorSizeE>
 struct TensorDescriptorUtils
 {
     /// @brief Creates a tensor descriptor for input tensor A with batch dimensions removed.
@@ -62,9 +65,9 @@ struct TensorDescriptorUtils
         const auto dims_M = get_container_subset(A_dims_M_K, A_dims_M_ids);
         const auto dims_K = get_container_subset(A_dims_M_K, A_dims_K_ids);

-        // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] Discriptor
-        const auto A_grid_desc_Ms_Ks =
-            ck_tile::make_naive_tensor_descriptor(A_dims_M_K, A_strides_M_K);
+        // naive tensor A[M0, M1, M2, ..., K0, K1, K2...] Descriptor with vector size
+        const auto A_grid_desc_Ms_Ks = ck_tile::make_naive_tensor_descriptor(
+            A_dims_M_K, A_strides_M_K, number<VectorSizeA>{}, number<1>{});

         // transformed tensor to flatten M and K dimensions [M_total = M0 * M1 * M2 * ... , K_total
         // = K0 * K1 * K2 * ...]
@@ -106,9 +109,9 @@ struct TensorDescriptorUtils
         const auto dims_N = get_container_subset(B_dims_N_K, B_dims_N_ids);
         const auto dims_K = get_container_subset(B_dims_N_K, B_dims_K_ids);

-        // naive tensor B[N0, N1, N2, ..., K0, K1, K2...] Discriptor
-        const auto B_grid_desc_Ns_Ks =
-            ck_tile::make_naive_tensor_descriptor(B_dims_N_K, B_strides_N_K);
+        // naive tensor B[N0, N1, N2, ..., K0, K1, K2...] Descriptor with vector size
+        const auto B_grid_desc_Ns_Ks = ck_tile::make_naive_tensor_descriptor(
+            B_dims_N_K, B_strides_N_K, number<VectorSizeB>{}, number<1>{});

         // transformed tensor to flatten N and K dimensions [N_total = N0 * N1 * N2 * ... , K_total
         // = K0 * K1 * K2 * ...]
@@ -150,9 +153,9 @@ struct TensorDescriptorUtils
         const auto dims_M = get_container_subset(E_dims_M_N, E_dims_M_ids);
        const auto dims_N = get_container_subset(E_dims_M_N, E_dims_N_ids);

-        // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] Discriptor
-        const auto E_grid_desc_Ms_Ns =
-            ck_tile::make_naive_tensor_descriptor(E_dims_M_N, E_strides_M_N);
+        // naive tensor E[M0, M1, M2, ..., N0, N1, N2...] Descriptor with vector size
+        const auto E_grid_desc_Ms_Ns = ck_tile::make_naive_tensor_descriptor(
+            E_dims_M_N, E_strides_M_N, number<VectorSizeE>{}, number<1>{});

         // transformed tensor to flatten M and N dimensions [M_total = M0 * M1 * M2 * ... ,
         // N_total = N0 * N1 * N2 * ...]
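
Based on the call shape above, the trailing number<VectorSizeX>{} / number<1>{} pair reads as a guarantee handed to the descriptor: the innermost dimension supports vector access of that width at unit stride, which is what lets the generated loads and stores vectorize. A host-side sanity check for when such a guarantee is safe might look like this (vector_guarantee_ok is an illustrative helper, not a ck_tile API):

#include <cstdint>
#include <vector>

using index_t = std::int64_t;

// True when a vector-size guarantee is safe: the fastest-varying dimension
// must be unit-stride and its length divisible by the vector width.
bool vector_guarantee_ok(const std::vector<index_t>& lengths,
                         const std::vector<index_t>& strides,
                         index_t vec)
{
    const auto last = lengths.size() - 1;
    return strides[last] == 1 && lengths[last] % vec == 0;
}

int main()
{
    // A[M0=4, K0=64] contiguous: an 8-wide guarantee holds.
    bool ok = vector_guarantee_ok({4, 64}, {64, 1}, 8); // true
    (void)ok;
}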