mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-12 01:10:17 +00:00
amd build
This commit is contained in:
@@ -313,14 +313,14 @@ struct TensorCoordinate
|
||||
private:
|
||||
template <class... Ts>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDummyTensorCoordinate(ConstantTensorDescriptor<Ts...>)
|
||||
MakeDummyTensorCoordinate(ConstantTensorDescriptor<Ts...>)
|
||||
{
|
||||
return NormalTensorCoordinate<ConstantTensorDescriptor<Ts...>>();
|
||||
}
|
||||
|
||||
template <class... Ts>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDummyTensorCoordinate(ConstantMergedTensorDescriptor<Ts...>)
|
||||
MakeDummyTensorCoordinate(ConstantMergedTensorDescriptor<Ts...>)
|
||||
{
|
||||
return MergedTensorCoordinate<ConstantMergedTensorDescriptor<Ts...>>();
|
||||
}
|
||||
|
||||
@@ -188,7 +188,7 @@ struct TensorCoordinate_v2
|
||||
private:
|
||||
template <typename... Ts>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDummyTensorCoordinate(NativeTensorDescriptor<Ts...>)
|
||||
MakeDummyTensorCoordinate(NativeTensorDescriptor<Ts...>)
|
||||
{
|
||||
return NativeTensorCoordinate<NativeTensorDescriptor<Ts...>>(
|
||||
make_zero_array<index_t, TensorDesc::GetNumOfDimension()>());
|
||||
@@ -196,7 +196,7 @@ struct TensorCoordinate_v2
|
||||
|
||||
template <typename... Ts>
|
||||
__host__ __device__ static constexpr auto
|
||||
MakeDummyTensorCoordinate(TransformedTensorDescriptor<Ts...>)
|
||||
MakeDummyTensorCoordinate(TransformedTensorDescriptor<Ts...>)
|
||||
{
|
||||
return TransformedTensorCoordinate<TransformedTensorDescriptor<Ts...>>(
|
||||
make_zero_array<index_t, TensorDesc::GetNumOfDimension()>());
|
||||
|
||||
@@ -13,13 +13,20 @@
|
||||
|
||||
namespace ck {
|
||||
|
||||
using unsigned_t = uint32_t;
|
||||
using signed_t = int;
|
||||
|
||||
#if 0 // debug
|
||||
using index_t = unsigned_t;
|
||||
#else
|
||||
using index_t = signed_t;
|
||||
#endif
|
||||
|
||||
// For some reason, HIP compiler need this definition to generate optimal load and store
|
||||
// instruction
|
||||
typedef float float2_t __attribute__((ext_vector_type(2)));
|
||||
typedef float float4_t __attribute__((ext_vector_type(4)));
|
||||
|
||||
using index_t = uint32_t;
|
||||
|
||||
template <class T>
|
||||
__device__ void fused_multiply_accumulate(T& d, const T& s0, const T& s1)
|
||||
{
|
||||
|
||||
@@ -51,11 +51,9 @@ struct TupleImpl;
|
||||
template <index_t... Is, typename... Xs>
|
||||
struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>...
|
||||
{
|
||||
#if 1
|
||||
__host__ __device__ explicit constexpr TupleImpl() : TupleElement<TupleElementKey<Is>, Xs>()...
|
||||
{
|
||||
}
|
||||
#endif
|
||||
|
||||
template <typename... Ys>
|
||||
__host__ __device__ explicit constexpr TupleImpl(Ys&&... ys)
|
||||
@@ -95,14 +93,14 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
|
||||
__host__ __device__ constexpr const auto& At(Number<I>) const
|
||||
{
|
||||
static_assert(I < base::Size(), "wrong! out of range");
|
||||
return GetElementByKey(detail::TupleElementKey<I>{});
|
||||
return base::GetElementByKey(detail::TupleElementKey<I>{});
|
||||
}
|
||||
|
||||
template <index_t I>
|
||||
__host__ __device__ constexpr auto& At(Number<I>)
|
||||
{
|
||||
static_assert(I < base::Size(), "wrong! out of range");
|
||||
return GetElementByKey(detail::TupleElementKey<I>{});
|
||||
return base::GetElementByKey(detail::TupleElementKey<I>{});
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -74,20 +74,20 @@ int main(int argc, char* argv[])
|
||||
{
|
||||
using namespace ck;
|
||||
|
||||
#if 0
|
||||
constexpr index_t N = 32;
|
||||
constexpr index_t C = 8;
|
||||
constexpr index_t HI = 1;
|
||||
constexpr index_t WI = 1;
|
||||
constexpr index_t K = 128;
|
||||
constexpr index_t Y = 1;
|
||||
constexpr index_t X = 1;
|
||||
#if 1
|
||||
constexpr index_t N = 256;
|
||||
constexpr index_t C = 64;
|
||||
constexpr index_t HI = 17;
|
||||
constexpr index_t WI = 17;
|
||||
constexpr index_t K = 256;
|
||||
constexpr index_t Y = 17;
|
||||
constexpr index_t X = 17;
|
||||
|
||||
using ConvStrides = Sequence<1, 1>;
|
||||
using ConvDilations = Sequence<1, 1>;
|
||||
|
||||
using LeftPads = Sequence<1, 1>;
|
||||
using RightPads = Sequence<0, 0>;
|
||||
using LeftPads = Sequence<0, 3>;
|
||||
using RightPads = Sequence<0, 3>;
|
||||
#elif 1
|
||||
// 3x3, 34x34
|
||||
constexpr index_t N = 64;
|
||||
|
||||
Reference in New Issue
Block a user