mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Add descriptor-based architecture for batched contraction multi-dimensional stride support
This commit is contained in:
@@ -5,6 +5,7 @@
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
#include "ck_tile/ops/batched_contraction/pipeline/batched_contraction_problem.hpp"
|
||||
#include "ck_tile/ops/batched_contraction/utils/tensor_descriptor_utils.hpp"
|
||||
#include "ck_tile/ops/gemm/kernel/universal_gemm_kernel.hpp"
|
||||
|
||||
/**
|
||||
@@ -210,11 +211,28 @@ struct BatchedContractionKernelArgs
|
||||
ck_tile::index_t N_total; ///< Total N dimension: N0 * N1 * ... * N_{NumDimN-1}
|
||||
ck_tile::index_t K_total; ///< Total K dimension: K0 * K1 * ... * K_{NumDimK-1}
|
||||
|
||||
ck_tile::index_t stride_A; ///< Leading dimension stride for tensor A (row-major: K_total)
|
||||
ck_tile::index_t stride_B; ///< Leading dimension stride for tensor B (row-major: K_total)
|
||||
ck_tile::index_t
|
||||
stride_A; ///< Leading dimension stride for tensor A (for backward compatibility)
|
||||
ck_tile::index_t
|
||||
stride_B; ///< Leading dimension stride for tensor B (for backward compatibility)
|
||||
std::array<ck_tile::index_t, NumDTensor>
|
||||
stride_Ds; ///< Leading dimension strides for D tensors (row-major: N_total)
|
||||
ck_tile::index_t stride_E; ///< Leading dimension stride for tensor E (row-major: N_total)
|
||||
stride_Ds; ///< Leading dimension strides for D tensors (for backward compatibility)
|
||||
ck_tile::index_t
|
||||
stride_E; ///< Leading dimension stride for tensor E (for backward compatibility)
|
||||
|
||||
// Tensor descriptors (encode full multi-dimensional stride information)
|
||||
// These are created on host and passed to device (like old CK)
|
||||
using AGridDesc_M_K_ = decltype(TensorDescriptorUtils<NumDimG, NumDimM, NumDimN, NumDimK>::
|
||||
Make_A_GridDescriptor_M_K({}, {}));
|
||||
using BGridDesc_N_K_ = decltype(TensorDescriptorUtils<NumDimG, NumDimM, NumDimN, NumDimK>::
|
||||
Make_B_GridDescriptor_N_K({}, {}));
|
||||
using EGridDesc_M_N_ = decltype(TensorDescriptorUtils<NumDimG, NumDimM, NumDimN, NumDimK>::
|
||||
Make_E_GridDescriptor_M_N({}, {}));
|
||||
|
||||
AGridDesc_M_K_ a_grid_desc_m_k; ///< Tensor descriptor for A[M, K] with actual strides
|
||||
BGridDesc_N_K_ b_grid_desc_n_k; ///< Tensor descriptor for B[N, K] with actual strides
|
||||
EGridDesc_M_N_ e_grid_desc_m_n; ///< Tensor descriptor for E[M, N] with actual strides
|
||||
// TODO: Add D descriptors array when needed
|
||||
};
|
||||
|
||||
/// @brief GPU kernel for batched tensor contraction operations.
|
||||
@@ -274,6 +292,14 @@ struct BatchedContractionKernel
|
||||
static constexpr ck_tile::index_t kBlockSize =
|
||||
UniversalGemmKernel::kBlockSize; ///< GPU block size inherited from GEMM kernel
|
||||
|
||||
// Tensor descriptor utilities for creating stride-aware descriptors
|
||||
using DescriptorUtils = TensorDescriptorUtils<NumDimG, NumDimM, NumDimN, NumDimK>;
|
||||
|
||||
// Tensor descriptor types (created on host, encode all stride information)
|
||||
using AGridDesc_M_K = decltype(DescriptorUtils::Make_A_GridDescriptor_M_K({}, {}));
|
||||
using BGridDesc_N_K = decltype(DescriptorUtils::Make_B_GridDescriptor_N_K({}, {}));
|
||||
using EGridDesc_M_N = decltype(DescriptorUtils::Make_E_GridDescriptor_M_N({}, {}));
|
||||
|
||||
using KernelArgs =
|
||||
BatchedContractionKernelArgs<NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor>; ///< Kernel
|
||||
///< argument
|
||||
@@ -435,6 +461,15 @@ struct BatchedContractionKernel
|
||||
kargs.K_total *= kargs.K_dims[i];
|
||||
}
|
||||
|
||||
// Create tensor descriptors on host using actual dims and strides
|
||||
kargs.a_grid_desc_m_k =
|
||||
DescriptorUtils::Make_A_GridDescriptor_M_K(host_args.A_dims, host_args.A_strides);
|
||||
kargs.b_grid_desc_n_k =
|
||||
DescriptorUtils::Make_B_GridDescriptor_N_K(host_args.B_dims, host_args.B_strides);
|
||||
kargs.e_grid_desc_m_n =
|
||||
DescriptorUtils::Make_E_GridDescriptor_M_N(host_args.E_dims, host_args.E_strides);
|
||||
|
||||
// Keep simple strides for backward compatibility
|
||||
kargs.stride_A = kargs.K_total;
|
||||
kargs.stride_B = kargs.K_total;
|
||||
kargs.stride_E = kargs.N_total;
|
||||
|
||||
Reference in New Issue
Block a user