adding implicit gemm

[ROCm/composable_kernel commit: e7b8705b91]
This commit is contained in:
Chao Liu
2019-01-15 18:11:41 -06:00
parent aa885b185d
commit 8b3c613be1
10 changed files with 510 additions and 231 deletions

View File

@@ -1,6 +1,22 @@
#pragma once
#include "common.cuh"
// this is ugly, only for 4d
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3>)
{
return Sequence<L1 * L2 * L3, L2 * L3, L3, 1>{};
}
// this is ugly, only for 4d
template <unsigned S0, unsigned S1, unsigned S2, unsigned S3>
__host__ __device__ constexpr auto calculate_full_lengths(Sequence<S0, S1, S2, S3>)
{
static_assert((S0 % S1 == 0) && (S1 % S2 == 0) && (S2 % S3 == 0), "cannot be evenly divided!");
return Sequence<1, S0 / S1, S1 / S2, S2 / S3>{};
}
template <class Lengths, class Strides>
struct ConstantTensorDescriptor
{
@@ -69,24 +85,14 @@ struct ConstantTensorDescriptor
static_assert(nDim == 4, "nDim is not 4");
return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3);
}
__host__ __device__ constexpr auto Condense() const
{
constexpr auto default_strides = calculate_default_strides(Lengths{});
return ConstantTensorDescriptor<Lengths, decltype(default_strides)>{};
}
};
// this is ugly, only for 4d
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3>
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3>)
{
return Sequence<L1 * L2 * L3, L2 * L3, L3, 1>{};
}
// this is ugly, only for 4d
template <unsigned S0, unsigned S1, unsigned S2, unsigned S3>
__host__ __device__ constexpr auto calculate_full_lengths(Sequence<S0, S1, S2, S3>)
{
static_assert((S0 % S1 == 0) && (S1 % S2 == 0) && (S2 % S3 == 0), "cannot be evenly divided!");
return Sequence<1, S0 / S1, S1 / S2, S2 / S3>{};
}
template <class Lengths>
__host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths)
{
@@ -124,4 +130,4 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
desc.GetStride(I1),
desc.GetStride(I2),
desc.GetStride(I3));
}
}