mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 13:11:25 +00:00
[CK Tile] batched contraction kernel generalizing (#3126)
* Add help for example
* Refactor the batched contraction reference computation to handle stride-aware calculation, plus some code cleanup
* Add stride-aware reference for batched contraction with independent D tensor layouts
* Add -num_d argument for runtime D tensor count selection in batched contraction
* Add stride vector arguments in example code for testing non-contiguous batched contraction inputs
* Add descriptor-based architecture for batched contraction multi-dimensional stride support
* Add multi-dimensional non-contiguous stride support to batched contraction, num_d = 0
* Add complete multi-dimensional stride support via descriptors
* Enable vectorization in descriptor-based batched contraction. Add pad_tensor_view to local RunGemm
* Clean up batched contraction: remove old UniversalGemmKernel path
* Clean up batched contraction: remove legacy paths and finalize docs
* Optimize batched contraction example: pass dimension sizes, not vectors
* Correct the reference calculation: unsigned int to int
* Fix batched_contraction C++17 build errors for gfx90a CI
This commit is contained in:
@@ -13,8 +13,8 @@
|
||||
* dimensions for GEMM operations. These functions transform multi-dimensional tensors into
|
||||
* 2D matrix descriptors by removing batch dimensions and flattening the remaining dimensions.
|
||||
*
|
||||
* These utilities are currently not used in the main batched contraction kernel but are preserved
|
||||
* for future implementations that may require explicit tensor descriptor creation.
|
||||
* These utilities are used by BatchedContractionKernel to create stride-aware descriptors
|
||||
* that support arbitrary multi-dimensional non-contiguous tensor layouts.
|
||||
*/
|
||||
|
||||
namespace ck_tile {
|
||||
@@ -30,7 +30,10 @@ namespace ck_tile {
|
||||
template <ck_tile::index_t NumDimG,
|
||||
ck_tile::index_t NumDimM,
|
||||
ck_tile::index_t NumDimN,
|
||||
ck_tile::index_t NumDimK>
|
||||
ck_tile::index_t NumDimK,
|
||||
ck_tile::index_t VectorSizeA,
|
||||
ck_tile::index_t VectorSizeB,
|
||||
ck_tile::index_t VectorSizeE>
|
||||
struct TensorDescriptorUtils
|
||||
{
|
||||
/// @brief Creates a tensor descriptor for input tensor A with batch dimensions removed.
|
||||
@@ -62,9 +65,9 @@ struct TensorDescriptorUtils
|
||||
const auto dims_M = get_container_subset(A_dims_M_K, A_dims_M_ids);
|
||||
const auto dims_K = get_container_subset(A_dims_M_K, A_dims_K_ids);
|
||||
|
||||
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...] Discriptor
|
||||
const auto A_grid_desc_Ms_Ks =
|
||||
ck_tile::make_naive_tensor_descriptor(A_dims_M_K, A_strides_M_K);
|
||||
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...] Descriptor with vector size
|
||||
const auto A_grid_desc_Ms_Ks = ck_tile::make_naive_tensor_descriptor(
|
||||
A_dims_M_K, A_strides_M_K, number<VectorSizeA>{}, number<1>{});
|
||||
|
||||
// transformed tensor to flatten M and K dimensions [M_total = M0 * M1 * M2 * ... , K_total
|
||||
// = K0 * K1 * K2 * ...]
|
||||
@@ -106,9 +109,9 @@ struct TensorDescriptorUtils
|
||||
const auto dims_N = get_container_subset(B_dims_N_K, B_dims_N_ids);
|
||||
const auto dims_K = get_container_subset(B_dims_N_K, B_dims_K_ids);
|
||||
|
||||
// naive tensor B[N0, N1, N2, ..., K0, K1, K2...] Discriptor
|
||||
const auto B_grid_desc_Ns_Ks =
|
||||
ck_tile::make_naive_tensor_descriptor(B_dims_N_K, B_strides_N_K);
|
||||
// naive tensor B[N0, N1, N2, ..., K0, K1, K2...] Descriptor with vector size
|
||||
const auto B_grid_desc_Ns_Ks = ck_tile::make_naive_tensor_descriptor(
|
||||
B_dims_N_K, B_strides_N_K, number<VectorSizeB>{}, number<1>{});
|
||||
|
||||
// transformed tensor to flatten N and K dimensions [N_total = N0 * N1 * N2 * ... , K_total
|
||||
// = K0 * K1 * K2 * ...]
|
||||
@@ -150,9 +153,9 @@ struct TensorDescriptorUtils
|
||||
const auto dims_M = get_container_subset(E_dims_M_N, E_dims_M_ids);
|
||||
const auto dims_N = get_container_subset(E_dims_M_N, E_dims_N_ids);
|
||||
|
||||
// naive tensor E[M0, M1, M2, ..., N0, N1, N2...] Discriptor
|
||||
const auto E_grid_desc_Ms_Ns =
|
||||
ck_tile::make_naive_tensor_descriptor(E_dims_M_N, E_strides_M_N);
|
||||
// naive tensor E[M0, M1, M2, ..., N0, N1, N2...] Descriptor with vector size
|
||||
const auto E_grid_desc_Ms_Ns = ck_tile::make_naive_tensor_descriptor(
|
||||
E_dims_M_N, E_strides_M_N, number<VectorSizeE>{}, number<1>{});
|
||||
|
||||
// transformed tensor to flatten M and N dimensions [M_total = M0 * M1 * M2 * ... ,
|
||||
// N_total = N0 * N1 * N2 * ...]
|
||||
|
||||
Reference in New Issue
Block a user