mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-13 09:45:56 +00:00
* wip: test suite for batched gemm multiple d gemm multiple d, working on gridwise implenentation * wip: many fixes in implementation of batched gemm gemm multiple d * wip: batched gemm gemm multiple d gridwise op compiling, not working yet * fix: incorrect d0 grid indexing in batched gemm gemm multipled * feat: add instances for batched gemm add relu gemm add * chore: configure instance with low vector transfer size for odd sizes * chore: add some more validation to device batched gemm gemm multiple d, and removed template parameter that didn't really make sense * fix: upate device_batched_gemm_gemm_wmma to work with new gridwise changes * fix: disable odd size tests on XDL archs * chore: removed temporary logging * chore: update some references to C tensor to E tensor * Tentative fix for example template params * Tentative fix for non-multi-D batched gemm gemm device impl. * Tentative fix for xdl example template params * Tentative fix for profiler build on gfx90a * chore: improve device batched gemm gemm multi D comment to include all ops and dimensions * chore: explicitly call ck::make_tuple to prevent issues when std::make_tuple would apply * fix: make the gemm1 data types match what happens in the device op * feat: add d0s/d1s datatypes and layouts to the device op type string * chore: change element-wise op so addition happens in fp32 * chore: add static asserts for gemm0/gemm1 calculated wave sizes * chore: also updated other element-wise ops to use fp32 calculations * chore: log number of supported instances * chore: update instance comment * chore: disable kernel timing in example by default * fix: gemm1 wave size calculation * fix: make sure batched gemm multiple d gemm multiple d profiler performs correct type conversions * chore: remove increased tolerance in batched gemm gemm multiple d example * chore: add comment explaining that verification fails for certain input values * chore: clarify instance comment --------- Co-authored-by: kiefer 
<kiefer.van.teutem@streamhpc.com>
211 lines
6.0 KiB
C++
211 lines
6.0 KiB
C++
// Copyright (c) Advanced Micro Devices, Inc., or its affiliates.
|
|
// SPDX-License-Identifier: MIT
|
|
|
|
#pragma once
|
|
|
|
#include "functional4.hpp"
|
|
#include "tuple.hpp"
|
|
#ifndef CK_CODE_GEN_RTC
|
|
#include "is_detected.hpp"
|
|
#endif
|
|
|
|
namespace ck {
|
|
|
|
// Build a Tuple by invoking f once per compile-time index in the given Sequence.
// f receives Number<id> for each id; results are collected left-to-right.
// ck::make_tuple is qualified explicitly so ADL cannot select std::make_tuple
// when f returns a std:: type.
template <typename F, index_t... ids>
__host__ __device__ constexpr auto generate_tuple_for(F&& f, Sequence<ids...>)
{
    return ck::make_tuple(f(Number<ids>{})...);
}
|
|
|
|
// Build a Tuple of N elements: (f(Number<0>), f(Number<1>), ..., f(Number<N-1>)).
// The size is carried as a compile-time Number<N> tag argument.
template <typename F, index_t N>
__host__ __device__ constexpr auto generate_tuple(F&& f, Number<N>)
{
    return generate_tuple_for(f, make_index_sequence<N>{});
}
|
|
|
|
template <typename F, index_t N>
|
|
__host__ __device__ constexpr auto generate_tuple(F&& f, LongNumber<N>)
|
|
{
|
|
return unpack([&f](auto&&... xs) { return make_tuple(f(xs)...); },
|
|
typename arithmetic_sequence_gen<0, N, 1>::type{});
|
|
}
|
|
|
|
// Like generate_tuple, but collects the results of f with tie(), producing a
// Tuple of references. f must therefore return lvalues whose lifetime outlives
// the returned tuple.
template <typename F, index_t N>
__host__ __device__ constexpr auto generate_tie(F&& f, Number<N>)
{
    return unpack([&f](auto&&... xs) { return tie(f(xs)...); },
                  typename arithmetic_sequence_gen<0, N, 1>::type{});
}
|
|
|
|
// tx and ty are tuples of references; the return type is also a tuple of
// references (not rvalues): each element of the result refers to the same
// object as the corresponding element of tx / ty.
template <typename... X, typename... Y>
__host__ __device__ constexpr auto concat_tuple_of_reference(const Tuple<X&...>& tx,
                                                             const Tuple<Y&...>& ty)
{
    return unpack2(
        // zs are forwarding references, so decltype(zs)... preserves
        // reference-ness in the result Tuple's element types
        [&](auto&&... zs) { return Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
        tx,
        ty);
}
|
|
|
|
// Concatenate two tuples by value: the result holds copies of tx's elements
// followed by copies of ty's elements.
template <typename... X, typename... Y>
__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuple<Y...>& ty)
{
    return unpack2(
        // zs are taken by value (note: no &&), so decltype(zs)... are value
        // types; ck::forward then moves each copy into the result tuple
        [&](auto... zs) { return Tuple<decltype(zs)...>{ck::forward<decltype(zs)>(zs)...}; },
        tx,
        ty);
}
|
|
|
|
// Support any number of tuples to concat (also 1)
// Base case: concatenating a single tuple simply returns a copy of it.
template <typename... X>
__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx)
{
    return tx;
}
|
|
|
|
// Recursive case: fold an arbitrary number of tuples right-to-left, i.e.
// concat(t0, t1, ..., tn) == concat(t0, concat(t1, ..., tn)).
template <typename... X, typename... Tuples>
__host__ __device__ constexpr auto concat_tuple(const Tuple<X...>& tx, const Tuples&... tuples)
{
    return concat_tuple(tx, concat_tuple(tuples...));
}
|
|
|
|
namespace detail {
|
|
|
|
template <typename F, typename X, index_t... Is>
|
|
__host__ __device__ constexpr auto transform_tuples_impl(F f, const X& x, Sequence<Is...>)
|
|
{
|
|
return make_tuple(f(x.At(Number<Is>{}))...);
|
|
}
|
|
|
|
template <typename F, typename X, typename Y, index_t... Is>
|
|
__host__ __device__ constexpr auto
|
|
transform_tuples_impl(F f, const X& x, const Y& y, Sequence<Is...>)
|
|
{
|
|
return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}))...);
|
|
}
|
|
|
|
template <typename F, typename X, typename Y, typename Z, index_t... Is>
|
|
__host__ __device__ constexpr auto
|
|
transform_tuples_impl(F f, const X& x, const Y& y, const Z& z, Sequence<Is...>)
|
|
{
|
|
return make_tuple(f(x.At(Number<Is>{}), y.At(Number<Is>{}), z.At(Number<Is>{}))...);
|
|
}
|
|
|
|
} // namespace detail
|
|
|
|
template <typename F, typename X>
|
|
__host__ __device__ constexpr auto transform_tuples(F f, const X& x)
|
|
{
|
|
return detail::transform_tuples_impl(
|
|
f, x, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
|
|
}
|
|
|
|
template <typename F, typename X, typename Y>
|
|
__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y)
|
|
{
|
|
return detail::transform_tuples_impl(
|
|
f, x, y, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
|
|
}
|
|
|
|
template <typename F, typename X, typename Y, typename Z>
|
|
__host__ __device__ constexpr auto transform_tuples(F f, const X& x, const Y& y, const Z& z)
|
|
{
|
|
return detail::transform_tuples_impl(
|
|
f, x, y, z, typename arithmetic_sequence_gen<0, X::Size(), 1>::type{});
|
|
}
|
|
|
|
// By default unroll to the flatten
// Base case: an empty tuple flattens to itself.
template <index_t Depth = 0, index_t MaxDepth = -1>
__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<>& element)
{
    return element;
}
|
|
|
|
template <index_t Depth = 0, index_t MaxDepth = -1, typename T>
|
|
__host__ __device__ constexpr auto UnrollNestedTuple(const T& element)
|
|
{
|
|
return make_tuple(element);
|
|
}
|
|
|
|
// Recursive case: flatten a (possibly nested) tuple depth-first. Recursion
// stops early when Depth reaches MaxDepth; with the default MaxDepth of -1 the
// tuple is flattened completely (Depth counts up from 0 and never hits -1).
template <index_t Depth = 0, index_t MaxDepth = -1, typename... Ts>
__host__ __device__ constexpr auto UnrollNestedTuple(const Tuple<Ts...>& tuple)
{
    if constexpr(Depth == MaxDepth)
    {
        // Depth limit reached: keep this level as-is.
        return tuple;
    }
    else
    {
        // Flatten each element one level deeper, then splice the resulting
        // tuples together into a single flat tuple.
        return unpack(
            [&](auto&&... ts) {
                return concat_tuple(UnrollNestedTuple<Depth + 1, MaxDepth>(ts)...);
            },
            tuple);
    }
}
|
|
|
|
template <typename... Ts>
|
|
__host__ __device__ constexpr auto TupleReverse(const Tuple<Ts...>& tuple)
|
|
{
|
|
return generate_tuple(
|
|
[&](auto i) {
|
|
using Idx = Number<Tuple<Ts...>::Size() - i - 1>;
|
|
return tuple.At(Idx{});
|
|
},
|
|
Number<Tuple<Ts...>::Size()>{});
|
|
}
|
|
|
|
// Reduce tuple values in specific range using Function
|
|
template <index_t Idx, index_t End, typename F, typename... Ts>
|
|
__host__ __device__ constexpr auto TupleReduce(F&& f, const Tuple<Ts...>& tuple)
|
|
{
|
|
static_assert(Idx < End, "Wrong parameters for TupleReduce");
|
|
if constexpr(Idx + 1 == End)
|
|
{
|
|
return tuple.At(Number<Idx>{});
|
|
}
|
|
else
|
|
{
|
|
return f(tuple.At(Number<Idx>{}), TupleReduce<Idx + 1, End>(f, tuple));
|
|
}
|
|
}
|
|
|
|
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
// Detection-idiom probe: well-formed only for types that expose an IsTuple()
// member (used with is_detected below).
// NOTE(review): this guard admits the case where CK_CODE_GEN_RTC is defined
// but __HIPCC_RTC__ is not, while the is_detected.hpp include above is guarded
// by #ifndef CK_CODE_GEN_RTC only — confirm the two conditions are meant to
// differ.
template <typename T>
using is_tuple = decltype(ck::declval<T&>().IsTuple());
#endif
|
|
|
|
// True iff any element type of the tuple is itself tuple-like (exposes
// IsTuple()), i.e. the tuple is nested at least one level deep.
// NOTE(review): when both __HIPCC_RTC__ and CK_CODE_GEN_RTC are defined the
// body is empty and the deduced return type becomes void — confirm this
// function is never used in that configuration.
template <typename... Ts>
__host__ __device__ constexpr auto IsNestedTuple(const Tuple<Ts...>&)
{
#if !defined(__HIPCC_RTC__) || !defined(CK_CODE_GEN_RTC)
    return (is_detected<is_tuple, Ts>::value || ...);
#endif
}
|
|
|
|
// Base case of TupleDepth: a non-tuple value contributes the accumulated
// nesting depth reached so far.
template <index_t depth = 0, typename T>
__host__ __device__ constexpr auto TupleDepth(const T&)
{
    return depth;
}
|
|
|
|
// Recursive case: the depth of a tuple is one more than the deepest of its
// element types. Note the recursion is driven by value-constructing Ts{}, so
// all element types must be default-constructible — presumably guaranteed for
// the descriptor types this is used with; confirm for new uses.
template <index_t depth = 0, typename... Ts>
__host__ __device__ constexpr auto TupleDepth(const Tuple<Ts...>&)
{
    return math::max(TupleDepth<depth + 1>(Ts{})...);
}
|
|
|
|
template <index_t from, index_t to, typename... Ts>
|
|
__host__ __device__ constexpr auto TupleSlice(const Tuple<Ts...>& tuple)
|
|
{
|
|
return generate_tuple(
|
|
[&](auto i) {
|
|
using Idx = Number<from + i>;
|
|
return tuple.At(Idx{});
|
|
},
|
|
Number<to - from>{});
|
|
}
|
|
|
|
} // namespace ck
|