mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-02 12:41:26 +00:00
[CK Tile] contraction multi d - kernel & example (#2901)
* Initial commit. create batched_contraction_kernel file * initial problem definition * implement initial example to launch kernel * add universal gemm to contraction. initial phase * complete implementation for special case all Dims are 1 and no Ds * clean code * initial changes to support multi dimensional G * more progress in implementing multiple G * tmp commit * manage dynamic NumDimG in kernel * improving example for multi M,N,K,G handling. start generalizing kernel. it is a temporary commit * implement the example for general Multi dimension G M N K and test different reference calculation algorithms * 2 functions for reference using multi dimensional and flat indexing * clean the code for muti dimentional G, M, N, K contraction and add some logs * Add Make descriptor function in kernel for merging Ms, Ns, Ks for A, B, E * some cleaning on kernel * clean the code for calculating the offsets from flatten batch number * Start adding MultiD support to kernel and example * more changes to manage multi D in kernel and example * manage passing multi d to kernel and testing. * complete multi D support in kernel. modify example code to support it * Correct algorithm to calc the correct offset values for D tensor batches and some code cleaning * Minor fix * Generalize example code for variable NumD tensors and apply cleanup based on review feedback * Refactored code and addressed review feedback * refactoring, cleaning, add documents, in kernel side and example codes * Optimize batch offset calculation in kernel * Inline CalculateBatchOffset in batched contraction kernel, update CHANGELOG.md --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,169 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
/**
|
||||
* @file tensor_descriptor_utils.hpp
|
||||
* @brief Utility functions for creating tensor descriptors in batched contraction operations
|
||||
*
|
||||
* @details This file contains utility functions for creating tensor descriptors with flattened
|
||||
* dimensions for GEMM operations. These functions transform multi-dimensional tensors into
|
||||
* 2D matrix descriptors by removing batch dimensions and flattening the remaining dimensions.
|
||||
*
|
||||
* These utilities are currently not used in the main batched contraction kernel but are preserved
|
||||
* for future implementations that may require explicit tensor descriptor creation.
|
||||
*/
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
/**
|
||||
* @brief Utility class for creating tensor descriptors in batched contraction operations
|
||||
*
|
||||
* @tparam NumDimG Number of batch dimensions
|
||||
* @tparam NumDimM Number of M (output row) dimensions
|
||||
* @tparam NumDimN Number of N (output column) dimensions
|
||||
* @tparam NumDimK Number of K (contraction) dimensions
|
||||
*/
|
||||
template <ck_tile::index_t NumDimG,
|
||||
ck_tile::index_t NumDimM,
|
||||
ck_tile::index_t NumDimN,
|
||||
ck_tile::index_t NumDimK>
|
||||
struct TensorDescriptorUtils
|
||||
{
|
||||
/// @brief Creates a tensor descriptor for input tensor A with batch dimensions removed.
|
||||
/// @param A_dims Dimension vector for tensor A: [G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
|
||||
/// @param A_strides Stride vector for tensor A: [G0, G1, ..., M0, M1, M2, ..., K0, K1, K2, ...]
|
||||
/// @return Flattened tensor descriptor: [M_total, K_total] for GEMM computation
|
||||
/// @details Removes batch dimensions and flattens M and K dimensions for efficient GEMM
|
||||
/// execution
|
||||
CK_TILE_HOST static constexpr auto
|
||||
Make_A_GridDescriptor_M_K(const std::vector<ck_tile::index_t>& A_dims = {},
|
||||
const std::vector<ck_tile::index_t>& A_strides = {})
|
||||
{
|
||||
const auto to_tuple = [&](auto& vec, auto start, auto end) {
|
||||
return generate_tuple([&](auto i) { return vec[start + i]; }, number<end - start>{});
|
||||
};
|
||||
|
||||
// Remove G Dimensions
|
||||
const auto A_dims_M_K =
|
||||
to_tuple(A_dims, number<NumDimG>{}, number<NumDimG + NumDimM + NumDimK>{});
|
||||
const auto A_strides_M_K =
|
||||
to_tuple(A_strides, number<NumDimG>{}, number<NumDimG + NumDimM + NumDimK>{});
|
||||
|
||||
// dimension Ids for M and K
|
||||
constexpr auto A_dims_M_ids = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
|
||||
constexpr auto A_dims_K_ids =
|
||||
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimK, 1>::type{};
|
||||
|
||||
// Dimensions for M [M0, M1, ...] and K [K0, K1, ...]
|
||||
const auto dims_M = get_container_subset(A_dims_M_K, A_dims_M_ids);
|
||||
const auto dims_K = get_container_subset(A_dims_M_K, A_dims_K_ids);
|
||||
|
||||
// naive tensor A[M0, M1, M2, ..., K0, K1, K2...] Discriptor
|
||||
const auto A_grid_desc_Ms_Ks =
|
||||
ck_tile::make_naive_tensor_descriptor(A_dims_M_K, A_strides_M_K);
|
||||
|
||||
// transformed tensor to flatten M and K dimensions [M_total = M0 * M1 * M2 * ... , K_total
|
||||
// = K0 * K1 * K2 * ...]
|
||||
const auto A_grid_desc_Mflat_Kflat = ck_tile::transform_tensor_descriptor(
|
||||
A_grid_desc_Ms_Ks,
|
||||
make_tuple(make_merge_transform(dims_M), make_merge_transform(dims_K)),
|
||||
make_tuple(A_dims_M_ids, A_dims_K_ids),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return A_grid_desc_Mflat_Kflat;
|
||||
}
|
||||
|
||||
/// @brief Creates a tensor descriptor for input tensor B with batch dimensions removed.
|
||||
/// @param B_dims Dimension vector for tensor B: [G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
|
||||
/// @param B_strides Stride vector for tensor B: [G0, G1, ..., N0, N1, N2, ..., K0, K1, K2, ...]
|
||||
/// @return Flattened tensor descriptor: [N_total, K_total] for GEMM computation
|
||||
/// @details Removes batch dimensions and flattens N and K dimensions for efficient GEMM
|
||||
/// execution
|
||||
CK_TILE_HOST static constexpr auto
|
||||
Make_B_GridDescriptor_N_K(const std::vector<ck_tile::index_t>& B_dims = {},
|
||||
const std::vector<ck_tile::index_t>& B_strides = {})
|
||||
{
|
||||
const auto to_tuple = [&](auto& vec, auto start, auto end) {
|
||||
return generate_tuple([&](auto i) { return vec[start + i]; }, number<end - start>{});
|
||||
};
|
||||
|
||||
// Remove G Dimensions
|
||||
const auto B_dims_N_K =
|
||||
to_tuple(B_dims, number<NumDimG>{}, number<NumDimG + NumDimN + NumDimK>{});
|
||||
const auto B_strides_N_K =
|
||||
to_tuple(B_strides, number<NumDimG>{}, number<NumDimG + NumDimN + NumDimK>{});
|
||||
|
||||
// dimension Ids for N and K
|
||||
constexpr auto B_dims_N_ids = typename arithmetic_sequence_gen<0, NumDimN, 1>::type{};
|
||||
constexpr auto B_dims_K_ids =
|
||||
typename arithmetic_sequence_gen<NumDimN, NumDimN + NumDimK, 1>::type{};
|
||||
|
||||
// Dimensions for N [N0, N1, ...] and K [K0, K1, ...]
|
||||
const auto dims_N = get_container_subset(B_dims_N_K, B_dims_N_ids);
|
||||
const auto dims_K = get_container_subset(B_dims_N_K, B_dims_K_ids);
|
||||
|
||||
// naive tensor B[N0, N1, N2, ..., K0, K1, K2...] Discriptor
|
||||
const auto B_grid_desc_Ns_Ks =
|
||||
ck_tile::make_naive_tensor_descriptor(B_dims_N_K, B_strides_N_K);
|
||||
|
||||
// transformed tensor to flatten N and K dimensions [N_total = N0 * N1 * N2 * ... , K_total
|
||||
// = K0 * K1 * K2 * ...]
|
||||
const auto B_grid_desc_Nflat_Kflat = ck_tile::transform_tensor_descriptor(
|
||||
B_grid_desc_Ns_Ks,
|
||||
make_tuple(make_merge_transform(dims_N), make_merge_transform(dims_K)),
|
||||
make_tuple(B_dims_N_ids, B_dims_K_ids),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return B_grid_desc_Nflat_Kflat;
|
||||
}
|
||||
|
||||
/// @brief Creates a tensor descriptor for output tensor E with batch dimensions removed.
|
||||
/// @param E_dims Dimension vector for tensor E: [G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
|
||||
/// @param E_strides Stride vector for tensor E: [G0, G1, ..., M0, M1, M2, ..., N0, N1, N2, ...]
|
||||
/// @return Flattened tensor descriptor: [M_total, N_total] for GEMM computation
|
||||
/// @details Removes batch dimensions and flattens M and N dimensions for efficient GEMM
|
||||
/// execution
|
||||
CK_TILE_HOST static constexpr auto
|
||||
Make_E_GridDescriptor_M_N(const std::vector<ck_tile::index_t>& E_dims = {},
|
||||
const std::vector<ck_tile::index_t>& E_strides = {})
|
||||
{
|
||||
const auto to_tuple = [&](auto& vec, auto start, auto end) {
|
||||
return generate_tuple([&](auto i) { return vec[start + i]; }, number<end - start>{});
|
||||
};
|
||||
|
||||
// Remove G dimensions
|
||||
const auto E_dims_M_N =
|
||||
to_tuple(E_dims, number<NumDimG>{}, number<NumDimG + NumDimM + NumDimN>{});
|
||||
const auto E_strides_M_N =
|
||||
to_tuple(E_strides, number<NumDimG>{}, number<NumDimG + NumDimM + NumDimN>{});
|
||||
|
||||
// dimension Ids for M and N
|
||||
constexpr auto E_dims_M_ids = typename arithmetic_sequence_gen<0, NumDimM, 1>::type{};
|
||||
constexpr auto E_dims_N_ids =
|
||||
typename arithmetic_sequence_gen<NumDimM, NumDimM + NumDimN, 1>::type{};
|
||||
|
||||
// Dimensions for M and N
|
||||
const auto dims_M = get_container_subset(E_dims_M_N, E_dims_M_ids);
|
||||
const auto dims_N = get_container_subset(E_dims_M_N, E_dims_N_ids);
|
||||
|
||||
// naive tensor E[M0, M1, M2, ..., N0, N1, N2...] Discriptor
|
||||
const auto E_grid_desc_Ms_Ns =
|
||||
ck_tile::make_naive_tensor_descriptor(E_dims_M_N, E_strides_M_N);
|
||||
|
||||
// transformed tensor to flatten M and N dimensions [M_total = M0 * M1 * M2 * ... ,
|
||||
// N_total = N0 * N1 * N2 * ...]
|
||||
const auto E_grid_desc_Mflat_Nflat = ck_tile::transform_tensor_descriptor(
|
||||
E_grid_desc_Ms_Ns,
|
||||
make_tuple(make_merge_transform(dims_M), make_merge_transform(dims_N)),
|
||||
make_tuple(E_dims_M_ids, E_dims_N_ids),
|
||||
make_tuple(sequence<0>{}, sequence<1>{}));
|
||||
|
||||
return E_grid_desc_Mflat_Nflat;
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user