mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-18 03:49:41 +00:00
@@ -1,6 +1,22 @@
|
||||
#pragma once
|
||||
#include "common.cuh"
|
||||
|
||||
// this is ugly, only for 4d
|
||||
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3>
|
||||
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3>)
|
||||
{
|
||||
return Sequence<L1 * L2 * L3, L2 * L3, L3, 1>{};
|
||||
}
|
||||
|
||||
// this is ugly, only for 4d
|
||||
template <unsigned S0, unsigned S1, unsigned S2, unsigned S3>
|
||||
__host__ __device__ constexpr auto calculate_full_lengths(Sequence<S0, S1, S2, S3>)
|
||||
{
|
||||
static_assert((S0 % S1 == 0) && (S1 % S2 == 0) && (S2 % S3 == 0), "cannot be evenly divided!");
|
||||
|
||||
return Sequence<1, S0 / S1, S1 / S2, S2 / S3>{};
|
||||
}
|
||||
|
||||
template <class Lengths, class Strides>
|
||||
struct ConstantTensorDescriptor
|
||||
{
|
||||
@@ -69,24 +85,14 @@ struct ConstantTensorDescriptor
|
||||
static_assert(nDim == 4, "nDim is not 4");
|
||||
return i0 * GetStride(I0) + i1 * GetStride(I1) + i2 * GetStride(I2) + i3 * GetStride(I3);
|
||||
}
|
||||
|
||||
__host__ __device__ constexpr auto Condense() const
|
||||
{
|
||||
constexpr auto default_strides = calculate_default_strides(Lengths{});
|
||||
return ConstantTensorDescriptor<Lengths, decltype(default_strides)>{};
|
||||
}
|
||||
};
|
||||
|
||||
// this is ugly, only for 4d
|
||||
template <unsigned L0, unsigned L1, unsigned L2, unsigned L3>
|
||||
__host__ __device__ constexpr auto calculate_default_strides(Sequence<L0, L1, L2, L3>)
|
||||
{
|
||||
return Sequence<L1 * L2 * L3, L2 * L3, L3, 1>{};
|
||||
}
|
||||
|
||||
// this is ugly, only for 4d
|
||||
template <unsigned S0, unsigned S1, unsigned S2, unsigned S3>
|
||||
__host__ __device__ constexpr auto calculate_full_lengths(Sequence<S0, S1, S2, S3>)
|
||||
{
|
||||
static_assert((S0 % S1 == 0) && (S1 % S2 == 0) && (S2 % S3 == 0), "cannot be evenly divided!");
|
||||
|
||||
return Sequence<1, S0 / S1, S1 / S2, S2 / S3>{};
|
||||
}
|
||||
|
||||
template <class Lengths>
|
||||
__host__ __device__ constexpr auto make_ConstantTensorDescriptor(Lengths)
|
||||
{
|
||||
@@ -124,4 +130,4 @@ __host__ __device__ void print_ConstantTensorDescriptor(TDesc, const char* s)
|
||||
desc.GetStride(I1),
|
||||
desc.GetStride(I2),
|
||||
desc.GetStride(I3));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user