mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-03 21:21:22 +00:00
[CK Tile] contraction multi d - kernel & example (#2901)
* Initial commit. create batched_contraction_kernel file * initial problem definition * implement initial example to launch kernel * add universal gemm to contraction. initial phase * complete implementation for special case all Dims are 1 and no Ds * clean code * initial changes to support multi dimensional G * more progress in implementing multiple G * tmp commit * manage dynamic NumDimG in kernel * improving example for multi M,N,K,G handling. start generalizing kernel. it is a temporary commit * implement the example for general Multi dimension G M N K and test different reference calculation algorithms * 2 functions for reference using multi dimensional and flat indexing * clean the code for muti dimentional G, M, N, K contraction and add some logs * Add Make descriptor function in kernel for merging Ms, Ns, Ks for A, B, E * some cleaning on kernel * clean the code for calculating the offsets from flatten batch number * Start adding MultiD support to kernel and example * more changes to manage multi D in kernel and example * manage passing multi d to kernel and testing. * complete multi D support in kernel. modify example code to support it * Correct algorithm to calc the correct offset values for D tensor batches and some code cleaning * Minor fix * Generalize example code for variable NumD tensors and apply cleanup based on review feedback * Refactored code and addressed review feedback * refactoring, cleaning, add documents, in kernel side and example codes * Optimize batch offset calculation in kernel * Inline CalculateBatchOffset in batched contraction kernel, update CHANGELOG.md --------- Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,32 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
|
||||
#pragma once
|
||||
|
||||
#include "ck_tile/core.hpp"
|
||||
|
||||
namespace ck_tile {
|
||||
|
||||
template <typename ADataType_,
|
||||
typename BDataType_,
|
||||
typename DsDataType_,
|
||||
typename EDataType_,
|
||||
ck_tile::index_t NumDimG_,
|
||||
ck_tile::index_t NumDimM_,
|
||||
ck_tile::index_t NumDimN_,
|
||||
ck_tile::index_t NumDimK_,
|
||||
ck_tile::index_t NumDTensor_>
|
||||
struct BatchedContractionProblem
|
||||
{
|
||||
using ADataType = ck_tile::remove_cvref_t<ADataType_>;
|
||||
using BDataType = ck_tile::remove_cvref_t<BDataType_>;
|
||||
using DsDataType = ck_tile::remove_cvref_t<DsDataType_>;
|
||||
using EDataType = ck_tile::remove_cvref_t<EDataType_>;
|
||||
|
||||
static constexpr ck_tile::index_t NumDimG = NumDimG_;
|
||||
static constexpr ck_tile::index_t NumDimM = NumDimM_;
|
||||
static constexpr ck_tile::index_t NumDimN = NumDimN_;
|
||||
static constexpr ck_tile::index_t NumDimK = NumDimK_;
|
||||
static constexpr ck_tile::index_t NumDTensor = NumDTensor_;
|
||||
};
|
||||
|
||||
} // namespace ck_tile
|
||||
Reference in New Issue
Block a user