mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
Clean up batched contraction: remove legacy paths and finalize docs
This commit is contained in:
@@ -4,8 +4,6 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdlib>
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
#include <thread>
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
@@ -61,23 +61,19 @@
|
||||
* Rather than implementing tensor contraction from scratch, this kernel leverages the highly
|
||||
* optimized `UniversalGemmKernel` as its computational backend.
|
||||
*
|
||||
* @subsection current_limitations Current Kernel Limitations
|
||||
* @subsection implementation_features Implementation Features
|
||||
*
|
||||
* **Layout Restrictions:**
|
||||
* - **Row-Major Only**: All tensors must use row-major memory layout
|
||||
* - **Packed Tensors**: Only contiguous/packed tensor layouts supported
|
||||
* - **Hardcoded Strides**: stride_A = K_total, stride_B = K_total, stride_E = N_total
|
||||
* - **D Tensor Layout**: All D tensors must match E tensor layout (stride_Ds = N_total)
|
||||
* **Stride Support:**
|
||||
* - Supports arbitrary multi-dimensional stride patterns
|
||||
* - Handles non-contiguous and padded tensor layouts
|
||||
* - Independent strides for each auxiliary D tensor
|
||||
* - Descriptor-based architecture with vectorization
|
||||
*
|
||||
* **Implementation Constraints:**
|
||||
* - **Fixed Stride Calculation**: Strides are automatically calculated and cannot be customized
|
||||
* - **No Column-Major**: Column-major or custom stride patterns not supported
|
||||
* - **No Strided Access**: Non-contiguous tensor slicing not supported
|
||||
*
|
||||
* **Future Enhancements:**
|
||||
* - Support for arbitrary stride patterns
|
||||
* - Column-major and mixed layout support
|
||||
* - Non-contiguous tensor operation support
|
||||
* **Architecture:**
|
||||
* - Uses TensorDescriptorUtils for stride-aware descriptor creation
|
||||
* - Custom RunGemm implementation with descriptor-based tensor views
|
||||
* - Reuses GemmPipeline and EpiloguePipeline for computation
|
||||
* - Split-K support via UniversalGemmKernel utilities
|
||||
*/
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -379,10 +375,23 @@ struct BatchedContractionKernel
|
||||
TilePartitioner::GridSize(kargs.M_total, kargs.N_total), kargs.G_total, kargs.k_batch);
|
||||
}
|
||||
|
||||
/// @brief Custom RunGemm implementation using tensor descriptors to support
|
||||
/// multi-dimensional, non-contiguous tensors
|
||||
/// @details This function creates tensor views from descriptors and runs the GEMM pipeline,
|
||||
/// similar to UniversalGemmKernel::RunGemm but with descriptor-based tensor views
|
||||
/// @brief Executes GEMM computation using descriptor-based tensor views to support
|
||||
/// arbitrary stride patterns
|
||||
///
|
||||
/// @details This function performs the core GEMM computation using tensor descriptors to handle
|
||||
/// arbitrary multi-dimensional stride patterns. It creates tensor views from
|
||||
/// pre-computed descriptors (stored in kargs), applies padding, creates tile windows,
|
||||
/// and executes the GemmPipeline and EpiloguePipeline.
|
||||
///
|
||||
/// @param a_ptr Pointer to input tensor A data (after batch and split-K offsets applied)
|
||||
/// @param b_ptr Pointer to input tensor B data (after batch and split-K offsets applied)
|
||||
/// @param ds_ptr Array of pointers to auxiliary D tensor data
|
||||
/// @param e_ptr Pointer to output tensor E data (after batch offset applied)
|
||||
/// @param smem_ptr Pointer to shared memory for tile operations
|
||||
/// @param kargs Kernel arguments containing tensor descriptors and dimension information
|
||||
/// @param k_size Size of K dimension for this split (for split-K support)
|
||||
/// @param i_m Starting M index for this block's tile
|
||||
/// @param i_n Starting N index for this block's tile
|
||||
CK_TILE_DEVICE static void RunGemm(const ADataType* a_ptr,
|
||||
const BDataType* b_ptr,
|
||||
const std::array<const void*, NumDTensor>& ds_ptr,
|
||||
@@ -393,36 +402,13 @@ struct BatchedContractionKernel
|
||||
const index_t i_m,
|
||||
const index_t i_n)
|
||||
{
|
||||
#if 1 // DESCRIPTOR PATH: Full multi-dimensional stride support
|
||||
// Create tensor views from descriptors
|
||||
// Create tensor views from descriptors (supports arbitrary stride patterns)
|
||||
auto a_tensor_view =
|
||||
make_tensor_view<address_space_enum::global>(a_ptr, kargs.a_grid_desc_m_k);
|
||||
auto b_tensor_view =
|
||||
make_tensor_view<address_space_enum::global>(b_ptr, kargs.b_grid_desc_n_k);
|
||||
auto e_tensor_view =
|
||||
make_tensor_view<address_space_enum::global>(e_ptr, kargs.e_grid_desc_m_n);
|
||||
#else // NAIVE PATH: Simple views assuming contiguous (for performance testing)
|
||||
auto a_tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
a_ptr,
|
||||
make_tuple(kargs.M_total, k_size),
|
||||
make_tuple(kargs.K_total, 1),
|
||||
number<GemmPipeline::GetVectorSizeA()>{},
|
||||
number<1>{});
|
||||
|
||||
auto b_tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
b_ptr,
|
||||
make_tuple(kargs.N_total, k_size),
|
||||
make_tuple(kargs.K_total, 1),
|
||||
number<GemmPipeline::GetVectorSizeB()>{},
|
||||
number<1>{});
|
||||
|
||||
auto e_tensor_view = make_naive_tensor_view<address_space_enum::global>(
|
||||
e_ptr,
|
||||
make_tuple(kargs.M_total, kargs.N_total),
|
||||
make_tuple(kargs.N_total, 1),
|
||||
number<EpiloguePipeline::GetVectorSizeC()>{},
|
||||
number<1>{});
|
||||
#endif
|
||||
|
||||
// Pad views for boundary handling and optimization (like UniversalGemmKernel)
|
||||
auto a_pad_view = pad_tensor_view(
|
||||
|
||||
@@ -13,8 +13,8 @@
|
||||
* dimensions for GEMM operations. These functions transform multi-dimensional tensors into
|
||||
* 2D matrix descriptors by removing batch dimensions and flattening the remaining dimensions.
|
||||
*
|
||||
* These utilities are currently not used in the main batched contraction kernel but are preserved
|
||||
* for future implementations that may require explicit tensor descriptor creation.
|
||||
* These utilities are used by BatchedContractionKernel to create stride-aware descriptors
|
||||
* that support arbitrary multi-dimensional non-contiguous tensor layouts.
|
||||
*/
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
Reference in New Issue
Block a user