[CK_TILE] Multiple-D GEMM example (#2219)

* Multiple d, initial commit

* Check Ds Layout

* Readme and clang format

* Update branch & conflicts

* Multiple D - fix clang-formatter

* Rename elemetwise_op

* Fix CI

* Code review part1

* Remove printf

* Remove unnecessary comment

* Add new tests with Col layout

* Review part 2

* Added support for Multiple D GEMM

* Update comment

* Remove maybe_unused

* Clang-format

* Review part 3

* Add comment to function

* Add comment to function: another

* Take number of params for a refrence function

* Remove additional d param for 0 tensor

* Change name of function

* Fix CI fails

[ROCm/composable_kernel commit: bd96ac9742]
This commit is contained in:
Mateusz Ozga
2025-06-13 19:39:11 +02:00
committed by GitHub
parent ea36ae016e
commit 044a8560f7
34 changed files with 2267 additions and 285 deletions

View File

@@ -59,6 +59,38 @@ CK_TILE_DEVICE auto tile_elementwise_in(const InElementFunc& in_element_func,
return out_dstr_tensor;
}
/**
* @brief Template function that "unpacks" a tuple and applies an element-wise operation.
*
* @param in_element_func Function to apply element-wise.
* @param t Any container containing elements to process, with known size and
* tuple-like semantic.
* @return Calls tile_elementwise_inout with unpacked tuple elements.
*/
template <typename InElementFunc, typename Tuple, size_t... I>
CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc& in_element_func,
const Tuple& t,
std::index_sequence<I...>)
{
return tile_elementwise_inout(in_element_func, t[number<I>{}]...);
}
/**
* @brief Template function that "unpacks" a tuple and applies an element-wise operation.
*
* @param in_element_func Function to apply element-wise.
* @param t Any container containing elements to process, with known size and
* tuple-like semantic.
* @return Calls the overloaded function, passing an index sequence.
*/
template <typename InElementFunc, typename Tuple>
CK_TILE_DEVICE auto tile_elementwise_inout_unpack(const InElementFunc& in_element_func,
const Tuple& t)
{
static constexpr auto size = Tuple::size();
return tile_elementwise_inout_unpack(in_element_func, t, std::make_index_sequence<size>{});
}
template <typename DstrTensors, typename T>
CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, const T& value)
{