[CK Tile] contraction multi d - kernel & example (#2901)

* Initial commit. create batched_contraction_kernel file

* initial problem definition

* implement initial example to launch kernel

* add universal gemm to contraction. initial phase

* complete implementation for special case all Dims are 1 and no Ds

* clean code

* initial changes to support multi dimensional G

* more progress in implementing multiple G

* tmp commit

* manage dynamic NumDimG in kernel

* improving example for multi M,N,K,G handling. start generalizing kernel. it is a temporary commit

* implement the example for general Multi dimension G M N K and test different reference calculation algorithms

* 2 functions for reference using multi dimensional and flat indexing

* clean the code for multi-dimensional G, M, N, K contraction and add some logs

* Add Make descriptor function in kernel for merging Ms, Ns, Ks for A, B, E

* some cleaning on kernel

* clean the code for calculating the offsets from the flattened batch number

* Start adding MultiD support to kernel and example

* more changes to manage multi D in kernel and example

* manage passing multi d to kernel and testing.

* complete multi D support in kernel. modify example code to support it

* Correct algorithm to calc the correct offset values for D tensor batches and some code cleaning

* Minor fix

* Generalize example code for variable NumD tensors and apply cleanup based on review feedback

* Refactored code and addressed review feedback

* refactoring, cleaning, add documents, in kernel side and example codes

* Optimize batch offset calculation in kernel

* Inline CalculateBatchOffset in batched contraction kernel, update CHANGELOG.md

---------

Co-authored-by: Adam Osewski <19374865+aosewski@users.noreply.github.com>
This commit is contained in:
msaffari-amd
2025-10-13 12:30:28 +02:00
committed by GitHub
parent 95bdc7410c
commit e9f0cc83a8
11 changed files with 1802 additions and 0 deletions

@@ -0,0 +1,32 @@
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

template <typename ADataType_,
          typename BDataType_,
          typename DsDataType_,
          typename EDataType_,
          ck_tile::index_t NumDimG_,
          ck_tile::index_t NumDimM_,
          ck_tile::index_t NumDimN_,
          ck_tile::index_t NumDimK_,
          ck_tile::index_t NumDTensor_>
struct BatchedContractionProblem
{
    using ADataType  = ck_tile::remove_cvref_t<ADataType_>;
    using BDataType  = ck_tile::remove_cvref_t<BDataType_>;
    using DsDataType = ck_tile::remove_cvref_t<DsDataType_>;
    using EDataType  = ck_tile::remove_cvref_t<EDataType_>;
    static constexpr ck_tile::index_t NumDimG    = NumDimG_;
    static constexpr ck_tile::index_t NumDimM    = NumDimM_;
    static constexpr ck_tile::index_t NumDimN    = NumDimN_;
    static constexpr ck_tile::index_t NumDimK    = NumDimK_;
    static constexpr ck_tile::index_t NumDTensor = NumDTensor_;
};

} // namespace ck_tile
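
The sketch below shows how a problem descriptor of this shape might be instantiated, and how a flat batch index could be mapped to a tensor offset when there are several G dimensions. It is not the kernel code from this commit. The `sketch` namespace, its stand-in `index_t` and `remove_cvref_t`, the use of `std::tuple` for `DsDataType`, and the `batch_offset` helper are all assumptions made so the example compiles without the CK Tile headers. The row-major decomposition (last G dimension varies fastest) is likewise an assumed convention.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>
#include <tuple>
#include <type_traits>

namespace sketch {

// Stand-ins for ck_tile::index_t and ck_tile::remove_cvref_t (assumed equivalents).
using index_t = std::int32_t;

template <typename T>
using remove_cvref_t = std::remove_cv_t<std::remove_reference_t<T>>;

// Local copy of the problem descriptor so the example is self-contained.
template <typename ADataType_,
          typename BDataType_,
          typename DsDataType_,
          typename EDataType_,
          index_t NumDimG_,
          index_t NumDimM_,
          index_t NumDimN_,
          index_t NumDimK_,
          index_t NumDTensor_>
struct BatchedContractionProblem
{
    using ADataType  = remove_cvref_t<ADataType_>;
    using BDataType  = remove_cvref_t<BDataType_>;
    using DsDataType = remove_cvref_t<DsDataType_>;
    using EDataType  = remove_cvref_t<EDataType_>;
    static constexpr index_t NumDimG    = NumDimG_;
    static constexpr index_t NumDimM    = NumDimM_;
    static constexpr index_t NumDimN    = NumDimN_;
    static constexpr index_t NumDimK    = NumDimK_;
    static constexpr index_t NumDTensor = NumDTensor_;
};

// Hypothetical helper: split a flat batch index into one index per G dimension
// (last dimension fastest) and accumulate index * stride for each dimension.
template <index_t NumDimG>
constexpr std::int64_t batch_offset(index_t flat_batch,
                                    const std::array<index_t, NumDimG>& g_lengths,
                                    const std::array<std::int64_t, NumDimG>& g_strides)
{
    std::int64_t offset = 0;
    for(index_t d = NumDimG - 1; d >= 0; --d)
    {
        const auto i      = static_cast<std::size_t>(d);
        const index_t idx = flat_batch % g_lengths[i];
        flat_batch /= g_lengths[i];
        offset += static_cast<std::int64_t>(idx) * g_strides[i];
    }
    return offset;
}

} // namespace sketch

// Two batch dimensions, one M, N and K dimension each, and one D tensor.
using Problem = sketch::BatchedContractionProblem<float, float, std::tuple<float>, float,
                                                  2, 1, 1, 1, 1>;

static_assert(Problem::NumDimG == 2, "two batch dimensions");
static_assert(std::is_same<Problem::EDataType, float>::value, "cv/ref qualifiers are stripped");
```

With G lengths `{2, 3}` and strides `{300, 100}`, flat batch 4 decomposes to indices `(1, 1)`, so the helper returns `1 * 300 + 1 * 100 = 400`. Taking strides as an explicit argument lets the same routine serve A, B, each D and E, whose batch strides may differ.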