diff --git a/composable_kernel/include/tensor_description/tensor_coordinate.hpp b/composable_kernel/include/tensor_description/tensor_coordinate.hpp index 4600b682ac..1a2774b589 100644 --- a/composable_kernel/include/tensor_description/tensor_coordinate.hpp +++ b/composable_kernel/include/tensor_description/tensor_coordinate.hpp @@ -313,14 +313,14 @@ struct TensorCoordinate private: template __host__ __device__ static constexpr auto - MakeDummyTensorCoordinate(ConstantTensorDescriptor) + MakeDummyTensorCoordinate(ConstantTensorDescriptor) { return NormalTensorCoordinate>(); } template __host__ __device__ static constexpr auto - MakeDummyTensorCoordinate(ConstantMergedTensorDescriptor) + MakeDummyTensorCoordinate(ConstantMergedTensorDescriptor) { return MergedTensorCoordinate>(); } diff --git a/composable_kernel/include/tensor_description/tensor_coordinate_v2.hpp b/composable_kernel/include/tensor_description/tensor_coordinate_v2.hpp index 2a35457bda..0330b22438 100644 --- a/composable_kernel/include/tensor_description/tensor_coordinate_v2.hpp +++ b/composable_kernel/include/tensor_description/tensor_coordinate_v2.hpp @@ -188,7 +188,7 @@ struct TensorCoordinate_v2 private: template __host__ __device__ static constexpr auto - MakeDummyTensorCoordinate(NativeTensorDescriptor) + MakeDummyTensorCoordinate(NativeTensorDescriptor) { return NativeTensorCoordinate>( make_zero_array()); @@ -196,7 +196,7 @@ struct TensorCoordinate_v2 template __host__ __device__ static constexpr auto - MakeDummyTensorCoordinate(TransformedTensorDescriptor) + MakeDummyTensorCoordinate(TransformedTensorDescriptor) { return TransformedTensorCoordinate>( make_zero_array()); diff --git a/composable_kernel/include/utility/config_amd.hpp.in b/composable_kernel/include/utility/config_amd.hpp.in index a7762a59b4..9b1542e224 100644 --- a/composable_kernel/include/utility/config_amd.hpp.in +++ b/composable_kernel/include/utility/config_amd.hpp.in @@ -13,13 +13,20 @@ namespace ck { +using unsigned_t = uint32_t; +using signed_t = int; + +#if 0 // debug +using index_t = unsigned_t; +#else +using index_t = signed_t; +#endif + // For some reason, HIP compiler need this definition to generate optimal load and store // instruction typedef float float2_t __attribute__((ext_vector_type(2))); typedef float float4_t __attribute__((ext_vector_type(4))); -using index_t = uint32_t; - template __device__ void fused_multiply_accumulate(T& d, const T& s0, const T& s1) { diff --git a/composable_kernel/include/utility/tuple.hpp b/composable_kernel/include/utility/tuple.hpp index c26cad2ae6..27175fe625 100644 --- a/composable_kernel/include/utility/tuple.hpp +++ b/composable_kernel/include/utility/tuple.hpp @@ -51,11 +51,9 @@ struct TupleImpl; template struct TupleImpl, Xs...> : TupleElement, Xs>... { -#if 1 __host__ __device__ explicit constexpr TupleImpl() : TupleElement, Xs>()... { } -#endif template __host__ __device__ explicit constexpr TupleImpl(Ys&&... ys) @@ -95,14 +93,14 @@ struct Tuple : detail::TupleImpl) const { static_assert(I < base::Size(), "wrong! out of range"); - return GetElementByKey(detail::TupleElementKey{}); + return base::GetElementByKey(detail::TupleElementKey{}); } template __host__ __device__ constexpr auto& At(Number) { static_assert(I < base::Size(), "wrong! out of range"); - return GetElementByKey(detail::TupleElementKey{}); + return base::GetElementByKey(detail::TupleElementKey{}); } }; diff --git a/driver/src/driver.cpp b/driver/src/driver.cpp index a47073c5e7..caecfce7fa 100644 --- a/driver/src/driver.cpp +++ b/driver/src/driver.cpp @@ -74,20 +74,20 @@ int main(int argc, char* argv[]) { using namespace ck; -#if 0 - constexpr index_t N = 32; - constexpr index_t C = 8; - constexpr index_t HI = 1; - constexpr index_t WI = 1; - constexpr index_t K = 128; - constexpr index_t Y = 1; - constexpr index_t X = 1; +#if 1 + constexpr index_t N = 256; + constexpr index_t C = 64; + constexpr index_t HI = 17; + constexpr index_t WI = 17; + constexpr index_t K = 256; + constexpr index_t Y = 17; + constexpr index_t X = 17; using ConvStrides = Sequence<1, 1>; using ConvDilations = Sequence<1, 1>; - using LeftPads = Sequence<1, 1>; - using RightPads = Sequence<0, 0>; + using LeftPads = Sequence<0, 3>; + using RightPads = Sequence<0, 3>; #elif 1 // 3x3, 34x34 constexpr index_t N = 64;