mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Enable vectorization in descriptor-based batched contraction. Add pad_tensor_view to local RunGemm
This commit is contained in:
@@ -185,7 +185,10 @@ template <ck_tile::index_t NumDimG,
|
||||
ck_tile::index_t NumDimM,
|
||||
ck_tile::index_t NumDimN,
|
||||
ck_tile::index_t NumDimK,
|
||||
ck_tile::index_t NumDTensor = 0>
|
||||
ck_tile::index_t NumDTensor = 0,
|
||||
ck_tile::index_t VectorSizeA = 1,
|
||||
ck_tile::index_t VectorSizeB = 1,
|
||||
ck_tile::index_t VectorSizeE = 1>
|
||||
struct BatchedContractionKernelArgs
|
||||
{
|
||||
const void* a_ptr; ///< Pointer to input tensor A
|
||||
@@ -220,14 +223,31 @@ struct BatchedContractionKernelArgs
|
||||
ck_tile::index_t
|
||||
stride_E; ///< Leading dimension stride for tensor E (for backward compatibility)
|
||||
|
||||
// Tensor descriptors (encode full multi-dimensional stride information)
|
||||
// These are created on host and passed to device (like old CK)
|
||||
using AGridDesc_M_K_ = decltype(TensorDescriptorUtils<NumDimG, NumDimM, NumDimN, NumDimK>::
|
||||
Make_A_GridDescriptor_M_K({}, {}));
|
||||
using BGridDesc_N_K_ = decltype(TensorDescriptorUtils<NumDimG, NumDimM, NumDimN, NumDimK>::
|
||||
Make_B_GridDescriptor_N_K({}, {}));
|
||||
using EGridDesc_M_N_ = decltype(TensorDescriptorUtils<NumDimG, NumDimM, NumDimN, NumDimK>::
|
||||
Make_E_GridDescriptor_M_N({}, {}));
|
||||
// Tensor descriptors (encode full multi-dimensional stride information with vectorization)
|
||||
using AGridDesc_M_K_ =
|
||||
decltype(TensorDescriptorUtils<NumDimG,
|
||||
NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeE>::Make_A_GridDescriptor_M_K({}, {}));
|
||||
using BGridDesc_N_K_ =
|
||||
decltype(TensorDescriptorUtils<NumDimG,
|
||||
NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeE>::Make_B_GridDescriptor_N_K({}, {}));
|
||||
using EGridDesc_M_N_ =
|
||||
decltype(TensorDescriptorUtils<NumDimG,
|
||||
NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
VectorSizeA,
|
||||
VectorSizeB,
|
||||
VectorSizeE>::Make_E_GridDescriptor_M_N({}, {}));
|
||||
|
||||
AGridDesc_M_K_ a_grid_desc_m_k; ///< Tensor descriptor for A[M, K] with actual strides
|
||||
BGridDesc_N_K_ b_grid_desc_n_k; ///< Tensor descriptor for B[N, K] with actual strides
|
||||
@@ -293,18 +313,24 @@ struct BatchedContractionKernel
|
||||
static constexpr ck_tile::index_t kBlockSize =
|
||||
UniversalGemmKernel::kBlockSize; ///< GPU block size inherited from GEMM kernel
|
||||
|
||||
// Tensor descriptor utilities for creating stride-aware descriptors
|
||||
using DescriptorUtils = TensorDescriptorUtils<NumDimG, NumDimM, NumDimN, NumDimK>;
|
||||
// Tensor descriptor utilities with vectorization support
|
||||
using DescriptorUtils = TensorDescriptorUtils<NumDimG,
|
||||
NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
GemmPipeline::GetVectorSizeA(),
|
||||
GemmPipeline::GetVectorSizeB(),
|
||||
EpiloguePipeline::GetVectorSizeC()>;
|
||||
|
||||
// Tensor descriptor types (created on host, encode all stride information)
|
||||
using AGridDesc_M_K = decltype(DescriptorUtils::Make_A_GridDescriptor_M_K({}, {}));
|
||||
using BGridDesc_N_K = decltype(DescriptorUtils::Make_B_GridDescriptor_N_K({}, {}));
|
||||
using EGridDesc_M_N = decltype(DescriptorUtils::Make_E_GridDescriptor_M_N({}, {}));
|
||||
|
||||
using KernelArgs =
|
||||
BatchedContractionKernelArgs<NumDimG, NumDimM, NumDimN, NumDimK, NumDTensor>; ///< Kernel
|
||||
///< argument
|
||||
///< structure
|
||||
// Kernel arguments with vectorization support
|
||||
using KernelArgs = BatchedContractionKernelArgs<NumDimG,
|
||||
NumDimM,
|
||||
NumDimN,
|
||||
NumDimK,
|
||||
NumDTensor,
|
||||
GemmPipeline::GetVectorSizeA(),
|
||||
GemmPipeline::GetVectorSizeB(),
|
||||
EpiloguePipeline::GetVectorSizeC()>;
|
||||
|
||||
/// @brief Returns the kernel name for debugging and profiling purposes.
|
||||
/// @return Constant string identifier for this kernel
|
||||
@@ -363,36 +389,76 @@ struct BatchedContractionKernel
|
||||
EDataType* e_ptr,
|
||||
void* smem_ptr,
|
||||
const KernelArgs& kargs,
|
||||
const index_t k_size,
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
{
|
||||
// Create tensor views from descriptors (handles multi-dimensional strides)
|
||||
#if 1 // DESCRIPTOR PATH: Full multi-dimensional stride support
|
||||
// Create tensor views from descriptors
|
||||
auto a_tensor_view =
|
||||
make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_desc_m_k);
|
||||
auto b_tensor_view =
|
||||
make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_desc_n_k);
|
||||
auto e_tensor_view =
|
||||
make_tensor_view<address_space_enum::global>(e_ptr, kargs.e_grid_desc_m_n);
|
||||
#else // NAIVE PATH: Simple views assuming contiguous (for performance testing)
|
||||
auto a_tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(kargs.M_total, k_size),
|
||||
make_tuple(kargs.K_total, 1),
|
||||
number<GemmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
|
||||
// Create tile windows for this block's work
|
||||
auto a_block_window = make_tile_window(
|
||||
auto b_tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(kargs.N_total, k_size),
|
||||
make_tuple(kargs.K_total, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
|
||||
auto e_tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
e_ptr,
|
||||
make_tuple(kargs.M_total, kargs.N_total),
|
||||
make_tuple(kargs.N_total, 1),
|
||||
number<EpiloguePipeline::GetVectorSizeC()>{},
|
||||
number<1>{});
|
||||
#endif
|
||||
|
||||
// Pad views for boundary handling and optimization (like UniversalGemmKernel)
|
||||
auto a_pad_view = pad_tensor_view(
|
||||
a_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
|
||||
auto b_pad_view = pad_tensor_view(
|
||||
b_tensor_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadK>{});
|
||||
|
||||
auto e_pad_view = pad_tensor_view(
|
||||
e_tensor_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
sequence<false, GemmPipeline::kPadN>{});
|
||||
|
||||
// Create tile windows from PADDED views
|
||||
auto a_block_window = make_tile_window(
|
||||
a_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
{i_m, 0});
|
||||
|
||||
auto b_block_window = make_tile_window(
|
||||
b_tensor_view,
|
||||
b_pad_view,
|
||||
make_tuple(number<TilePartitioner::NPerBlock>{}, number<TilePartitioner::KPerBlock>{}),
|
||||
{i_n, 0});
|
||||
|
||||
auto e_block_window = make_tile_window(
|
||||
e_tensor_view,
|
||||
e_pad_view,
|
||||
make_tuple(number<TilePartitioner::MPerBlock>{}, number<TilePartitioner::NPerBlock>{}),
|
||||
{i_m, i_n});
|
||||
|
||||
// Calculate number of K loops
|
||||
const index_t num_loop =
|
||||
__builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(kargs.K_total));
|
||||
__builtin_amdgcn_readfirstlane(TilePartitioner::GetLoopNum(k_size));
|
||||
|
||||
// Run GEMM Pipeline (same as UniversalGemmKernel, but with descriptor-based windows)
|
||||
using AElementWise = remove_cvref_t<typename GemmPipeline::AElementWise>;
|
||||
@@ -631,8 +697,10 @@ struct BatchedContractionKernel
|
||||
i_m,
|
||||
i_n);
|
||||
#else // NEW PATH: Descriptor-based RunGemm
|
||||
// custom descriptor-based RunGemm with full multi-dimensional stride support
|
||||
RunGemm(a_ptr, b_ptr, ds_batch_ptr, e_ptr, smem_ptr, kargs, i_m, i_n);
|
||||
// Custom descriptor-based RunGemm with full multi-dimensional stride support
|
||||
// For now, use K_total (split-K to be properly implemented later)
|
||||
const index_t k_size = kargs.K_total;
|
||||
RunGemm(a_ptr, b_ptr, ds_batch_ptr, e_ptr, smem_ptr, kargs, k_size, i_m, i_n);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
@@ -30,7 +30,10 @@ namespace ck_tile {
|
||||
template <ck_tile::index_t NumDimG,
|
||||
ck_tile::index_t NumDimM,
|
||||
ck_tile::index_t NumDimN,
|
||||
ck_tile::index_t NumDimK>
|
||||
ck_tile::index_t NumDimK,
|
||||
ck_tile::index_t VectorSizeA,
|
||||
ck_tile::index_t VectorSizeB,
|
||||
ck_tile::index_t VectorSizeE>
|
||||
struct TensorDescriptorUtils
|
||||
{
|
||||
/// @brief Creates a tensor descriptor for input tensor A with batch dimensions removed.
|
||||
@@ -62,9 +65,9 @@ struct TensorDescriptorUtils
|
||||
const auto dims_M = get_container_subset(A_dims_M_K, A_dims_M_ids);
|
||||
const auto dims_K = get_container_subset(A_dims_M_K, A_dims_K_ids);
|
||||
|
||||
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...] Discriptor
|
||||
const auto A_grid_desc_Ms_Ks =
|
||||
ck_tile::make_naive_tensor_descriptor(A_dims_M_K, A_strides_M_K);
|
||||
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...] Descriptor with vector size
|
||||
const auto A_grid_desc_Ms_Ks = ck_tile::make_naive_tensor_descriptor(
|
||||
A_dims_M_K, A_strides_M_K, number<VectorSizeA>{}, number<1>{});
|
||||
|
||||
// transformed tensor to flatten M and K dimensions [M_total = M0 * M1 * M2 * ... , K_total
|
||||
// = K0 * K1 * K2 * ...]
|
||||
@@ -106,9 +109,9 @@ struct TensorDescriptorUtils
|
||||
const auto dims_N = get_container_subset(B_dims_N_K, B_dims_N_ids);
|
||||
const auto dims_K = get_container_subset(B_dims_N_K, B_dims_K_ids);
|
||||
|
||||
// naive tensor B[N0, N1, N2, ..., K0, K1, K2...] Discriptor
|
||||
const auto B_grid_desc_Ns_Ks =
|
||||
ck_tile::make_naive_tensor_descriptor(B_dims_N_K, B_strides_N_K);
|
||||
// naive tensor B[N0, N1, N2, ..., K0, K1, K2...] Descriptor with vector size
|
||||
const auto B_grid_desc_Ns_Ks = ck_tile::make_naive_tensor_descriptor(
|
||||
B_dims_N_K, B_strides_N_K, number<VectorSizeB>{}, number<1>{});
|
||||
|
||||
// transformed tensor to flatten N and K dimensions [N_total = N0 * N1 * N2 * ... , K_total
|
||||
// = K0 * K1 * K2 * ...]
|
||||
@@ -150,9 +153,9 @@ struct TensorDescriptorUtils
|
||||
const auto dims_M = get_container_subset(E_dims_M_N, E_dims_M_ids);
|
||||
const auto dims_N = get_container_subset(E_dims_M_N, E_dims_N_ids);
|
||||
|
||||
// naive tensor E[M0, M1, M2, ..., N0, N1, N2...] Discriptor
|
||||
const auto E_grid_desc_Ms_Ns =
|
||||
ck_tile::make_naive_tensor_descriptor(E_dims_M_N, E_strides_M_N);
|
||||
// naive tensor E[M0, M1, M2, ..., N0, N1, N2...] Descriptor with vector size
|
||||
const auto E_grid_desc_Ms_Ns = ck_tile::make_naive_tensor_descriptor(
|
||||
E_dims_M_N, E_strides_M_N, number<VectorSizeE>{}, number<1>{});
|
||||
|
||||
// transformed tensor to flatten M and N dimensions [M_total = M0 * M1 * M2 * ... ,
|
||||
// N_total = N0 * N1 * N2 * ...]
|
||||
|
||||
Reference in New Issue
Block a user