mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-19 04:19:36 +00:00
added implicit gemm v1r3 lds_double_buffer NCHW * CYXK = KNHW, reworked static functionals
[ROCm/composable_kernel commit: 569ad66e2a]
This commit is contained in:
@@ -1,80 +1,30 @@
|
||||
#pragma once
|
||||
#include "common.hip.hpp"
|
||||
|
||||
// this is ugly, only for 2d
|
||||
template <index_t L0, index_t L1>
|
||||
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1>)
|
||||
template <class PreviousStrides, class RemainLengths>
|
||||
__host__ __device__ constexpr auto calculate_default_strides_impl(PreviousStrides, RemainLengths)
|
||||
{
|
||||
return Sequence<L1, 1>{};
|
||||
constexpr index_t previous_stride = PreviousStrides{}.Front();
|
||||
constexpr index_t current_length = RemainLengths{}.Back();
|
||||
constexpr index_t current_stride = current_length * previous_stride;
|
||||
|
||||
return calculate_default_strides_impl(PreviousStrides{}.PushFront(Number<current_stride>{}),
|
||||
RemainLengths{}.PopBack());
|
||||
}
|
||||
|
||||
// this is ugly, only for 3d
|
||||
template <index_t L0, index_t L1, index_t L2>
|
||||
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2>)
|
||||
template <class PreviousStrides, index_t L0, index_t L1>
|
||||
__host__ __device__ constexpr auto calculate_default_strides_impl(PreviousStrides, Sequence<L0, L1>)
|
||||
{
|
||||
return Sequence<L1 * L2, L2, 1>{};
|
||||
constexpr index_t previous_stride = PreviousStrides{}.Front();
|
||||
constexpr index_t current_stride = L1 * previous_stride;
|
||||
|
||||
return PreviousStrides{}.PushFront(Number<current_stride>{});
|
||||
}
|
||||
|
||||
// this is ugly, only for 4d
|
||||
template <index_t L0, index_t L1, index_t L2, index_t L3>
|
||||
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3>)
|
||||
template <class Lengths>
|
||||
__host__ __device__ constexpr auto calculate_default_strides(Lengths)
|
||||
{
|
||||
return Sequence<L1 * L2 * L3, L2 * L3, L3, 1>{};
|
||||
}
|
||||
|
||||
// this is ugly, only for 6d
|
||||
template <index_t L0, index_t L1, index_t L2, index_t L3, index_t L4, index_t L5>
|
||||
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5>)
|
||||
{
|
||||
return Sequence<L1 * L2 * L3 * L4 * L5, L2 * L3 * L4 * L5, L3 * L4 * L5, L4 * L5, L5, 1>{};
|
||||
}
|
||||
|
||||
// this is ugly, only for 8d
|
||||
template <index_t L0,
|
||||
index_t L1,
|
||||
index_t L2,
|
||||
index_t L3,
|
||||
index_t L4,
|
||||
index_t L5,
|
||||
index_t L6,
|
||||
index_t L7>
|
||||
__host__ __device__ constexpr auto
|
||||
calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5, L6, L7>)
|
||||
{
|
||||
return Sequence<L1 * L2 * L3 * L4 * L5 * L6 * L7,
|
||||
L2 * L3 * L4 * L5 * L6 * L7,
|
||||
L3 * L4 * L5 * L6 * L7,
|
||||
L4 * L5 * L6 * L7,
|
||||
L5 * L6 * L7,
|
||||
L6 * L7,
|
||||
L7,
|
||||
1>{};
|
||||
}
|
||||
|
||||
// this is ugly, only for 8d
|
||||
template <index_t L0,
|
||||
index_t L1,
|
||||
index_t L2,
|
||||
index_t L3,
|
||||
index_t L4,
|
||||
index_t L5,
|
||||
index_t L6,
|
||||
index_t L7,
|
||||
index_t L8,
|
||||
index_t L9>
|
||||
__host__ __device__ constexpr auto
|
||||
calculate_default_strides(Sequence<L0, L1, L2, L3, L4, L5, L6, L7, L8, L9>)
|
||||
{
|
||||
return Sequence<L1 * L2 * L3 * L4 * L5 * L6 * L7 * L8 * L9,
|
||||
L2 * L3 * L4 * L5 * L6 * L7 * L8 * L9,
|
||||
L3 * L4 * L5 * L6 * L7 * L8 * L9,
|
||||
L4 * L5 * L6 * L7 * L8 * L9,
|
||||
L5 * L6 * L7 * L8 * L9,
|
||||
L6 * L7 * L8 * L9,
|
||||
L7 * L8 * L9,
|
||||
L8 * L9,
|
||||
L9,
|
||||
1>{};
|
||||
return calculate_default_strides_impl(Sequence<1>{}, Lengths{});
|
||||
}
|
||||
|
||||
// this is ugly, only for 2d
|
||||
@@ -186,6 +136,14 @@ struct ConstantTensorDescriptor
|
||||
return Get1dIndex(multi_id);
|
||||
}
|
||||
|
||||
template <index_t... Is>
|
||||
__host__ __device__ static constexpr index_t Get1dIndex(Sequence<Is...> multi_id)
|
||||
{
|
||||
static_assert(sizeof...(Is) == nDim, "wrong! Dimension not consistent");
|
||||
|
||||
return Get1dIndex(Is...);
|
||||
}
|
||||
|
||||
__host__ __device__ static Array<index_t, nDim> GetMultiIndex(index_t id)
|
||||
{
|
||||
Array<index_t, nDim> multi_id;
|
||||
|
||||
Reference in New Issue
Block a user