Add a descriptor-based architecture for multi-dimensional stride support in batched contraction

This commit is contained in:
Mohsen Saffari
2025-10-20 10:30:23 +00:00
parent fec833263c
commit 2ecb0bfb3e

View File

@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp"
#include "ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp"
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
/**
@@ -210,11 +211,28 @@ struct BatchedContractionKernelArgs
ck_tile::index_t N_total; ///< Total N dimension: N0 * N1 * ... * N_{NumDimN-1}
ck_tile::index_t K_total; ///< Total K dimension: K0 * K1 * ... * K_{NumDimK-1}
ck_tile::index_t stride_A; ///< Leading dimension stride for tensor A (row-major: K_total)
ck_tile::index_t stride_B; ///< Leading dimension stride for tensor B (row-major: K_total)
ck_tile::index_t
stride_A; ///< Leading dimension stride for tensor A (for backward compatibility)
ck_tile::index_t
stride_B; ///< Leading dimension stride for tensor B (for backward compatibility)
std::array<ck_tile::index_t, NumDTensor>
stride_Ds; ///< Leading dimension strides for D tensors (row-major: N_total)
ck_tile::index_t stride_E; ///< Leading dimension stride for tensor E (row-major: N_total)
stride_Ds; ///< Leading dimension strides for D tensors (for backward compatibility)
ck_tile::index_t
stride_E; ///< Leading dimension stride for tensor E (for backward compatibility)
// Tensor descriptors (encode full multi-dimensional stride information)
// These are created on host and passed to device (like old CK)
using AGridDesc_M_K_ = decltype(TensorDescriptorUtils<NumDimG, NumDimM, NumDimN, NumDimK>::
Make_A_GridDescriptor_M_K({}, {}));
using BGridDesc_N_K_ = decltype(TensorDescriptorUtils<NumDimG, NumDimM, NumDimN, NumDimK>::
Make_B_GridDescriptor_N_K({}, {}));
using EGridDesc_M_N_ = decltype(TensorDescriptorUtils<NumDimG, NumDimM, NumDimN, NumDimK>::
Make_E_GridDescriptor_M_N({}, {}));
AGridDesc_M_K_ a_grid_desc_m_k; ///< Tensor descriptor for A[M, K] with actual strides
BGridDesc_N_K_ b_grid_desc_n_k; ///< Tensor descriptor for B[N, K] with actual strides
EGridDesc_M_N_ e_grid_desc_m_n; ///< Tensor descriptor for E[M, N] with actual strides
// TODO: Add D descriptors array when needed
};
/// @brief GPU kernel for batched tensor contraction operations.
@@ -274,6 +292,14 @@ struct BatchedContractionKernel
static constexpr ck_tile::index_t kBlockSize =
UniversalGemmKernel::kBlockSize; ///< GPU block size inherited from GEMM kernel
// Tensor descriptor utilities for creating stride-aware descriptors
using DescriptorUtils = TensorDescriptorUtils<NumDimG, NumDimM, NumDimN, NumDimK>;
// Tensor descriptor types (created on host, encode all stride information)
using AGridDesc_M_K = decltype(DescriptorUtils::Make_A_GridDescriptor_M_K({}, {}));
using BGridDesc_N_K = decltype(DescriptorUtils::Make_B_GridDescriptor_N_K({}, {}));
using EGridDesc_M_N = decltype(DescriptorUtils::Make_E_GridDescriptor_M_N({}, {}));
using KernelArgs =
BatchedContractionKernelArgs<NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor>; ///< Kernel
///< argument
@@ -435,6 +461,15 @@ struct BatchedContractionKernel
kargs.K_total *= kargs.K_dims[i];
}
// Create tensor descriptors on host using actual dims and strides
kargs.a_grid_desc_m_k =
DescriptorUtils::Make_A_GridDescriptor_M_K(host_args.A_dims, host_args.A_strides);
kargs.b_grid_desc_n_k =
DescriptorUtils::Make_B_GridDescriptor_N_K(host_args.B_dims, host_args.B_strides);
kargs.e_grid_desc_m_n =
DescriptorUtils::Make_E_GridDescriptor_M_N(host_args.E_dims, host_args.E_strides);
// Keep simple strides for backward compatibility
kargs.stride_A = kargs.K_total;
kargs.stride_B = kargs.K_total;
kargs.stride_E = kargs.N_total;