This commit is contained in:
Chao Liu
2021-08-16 20:36:47 +00:00
parent a91b68dfcb
commit 16effa767c
19 changed files with 99 additions and 91 deletions

View File

@@ -43,6 +43,7 @@ message(STATUS "Build with HIP ${hip_VERSION}")
message("HALF_INCLUDE_DIR: ${HALF_INCLUDE_DIR}") message("HALF_INCLUDE_DIR: ${HALF_INCLUDE_DIR}")
# CMAKE_CXX_FLAGS # CMAKE_CXX_FLAGS
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
if(BUILD_DEV) if(BUILD_DEV)
string(APPEND CMAKE_CXX_FLAGS " -Werror -Weverything") string(APPEND CMAKE_CXX_FLAGS " -Werror -Weverything")
endif() endif()

View File

@@ -377,7 +377,7 @@ struct RightPad
// at compile-time // at compile-time
template <typename UpLengths, template <typename UpLengths,
typename Coefficients, typename Coefficients,
typename std::enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false> typename enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
struct Embed struct Embed
{ {
static constexpr index_t NDimUp = UpLengths::Size(); static constexpr index_t NDimUp = UpLengths::Size();

View File

@@ -42,7 +42,7 @@ __host__ __device__ constexpr auto make_right_pad_transform(
template <typename UpLengths, template <typename UpLengths,
typename Coefficients, typename Coefficients,
typename std::enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false> typename enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
__host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_lengths, __host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_lengths,
const Coefficients& coefficients) const Coefficients& coefficients)
{ {

View File

@@ -454,9 +454,7 @@ __host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transf
remove_cv_t<decltype(top_dim_hidden_ids)>>{transforms}; remove_cv_t<decltype(top_dim_hidden_ids)>>{transforms};
} }
template <typename X, template <typename X, typename... Xs, typename enable_if<sizeof...(Xs) >= 2, bool>::type = false>
typename... Xs,
typename std::enable_if<sizeof...(Xs) >= 2, bool>::type = false>
__host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs) __host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
{ {
return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...)); return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));

View File

@@ -37,7 +37,7 @@ __host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengt
template <typename... Lengths, template <typename... Lengths,
typename... Strides, typename... Strides,
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false> typename enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
__host__ __device__ constexpr auto make_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths, __host__ __device__ constexpr auto make_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
const Tuple<Strides...>& strides) const Tuple<Strides...>& strides)
{ {

View File

@@ -22,24 +22,24 @@ namespace ck {
// 2. CThreadBuffer is StaticBuffer // 2. CThreadBuffer is StaticBuffer
// Also assume: // Also assume:
// M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization) // M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
template <index_t BlockSize, template <
typename FloatA, index_t BlockSize,
typename FloatB, typename FloatA,
typename FloatC, typename FloatB,
typename AKMBlockDesc, typename FloatC,
typename BKNBlockDesc, typename AKMBlockDesc,
index_t M1PerThreadM11, typename BKNBlockDesc,
index_t N1PerThreadN11, index_t M1PerThreadM11,
index_t KPerThread, index_t N1PerThreadN11,
index_t M1N1ThreadClusterM100, index_t KPerThread,
index_t M1N1ThreadClusterN100, index_t M1N1ThreadClusterM100,
index_t M1N1ThreadClusterM101, index_t M1N1ThreadClusterN100,
index_t M1N1ThreadClusterN101, index_t M1N1ThreadClusterM101,
index_t AThreadCopyScalarPerVector_M11, index_t M1N1ThreadClusterN101,
index_t BThreadCopyScalarPerVector_N11, index_t AThreadCopyScalarPerVector_M11,
typename std::enable_if<AKMBlockDesc::IsKnownAtCompileTime() && index_t BThreadCopyScalarPerVector_N11,
BKNBlockDesc::IsKnownAtCompileTime(), typename enable_if<AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2 struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
{ {
using AIndex = MultiIndex<3>; using AIndex = MultiIndex<3>;

View File

@@ -38,9 +38,9 @@ template <index_t BlockSize,
// BM10BN10ThreadClusterBN101, ...> // BM10BN10ThreadClusterBN101, ...>
index_t AThreadCopyScalarPerVector_BM11, index_t AThreadCopyScalarPerVector_BM11,
index_t BThreadCopyScalarPerVector_BN11, index_t BThreadCopyScalarPerVector_BN11,
typename std::enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() && typename enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(), BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2 struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
{ {
using AIndex = MultiIndex<3>; using AIndex = MultiIndex<3>;

View File

@@ -21,10 +21,10 @@ template <typename FloatA,
typename TKLengths, typename TKLengths,
typename TMLengths, typename TMLengths,
typename TNLengths, typename TNLengths,
typename std::enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1 struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
{ {
__device__ constexpr ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1() __device__ constexpr ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1()
@@ -123,10 +123,10 @@ template <typename FloatA,
typename TKLengths, typename TKLengths,
typename TMLengths, typename TMLengths,
typename TNLengths, typename TNLengths,
typename std::enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() && typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() && BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(), CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1 struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
{ {
__device__ constexpr ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1() __device__ constexpr ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()

View File

@@ -19,9 +19,9 @@ template <typename FloatA,
typename CDesc, typename CDesc,
index_t H, index_t H,
index_t W, index_t W,
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() && typename enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
CDesc::IsKnownAtCompileTime(), CDesc::IsKnownAtCompileTime(),
bool>::type = false> bool>::type = false>
struct ThreadwiseGemmDlops_km_kn_mn_v3 struct ThreadwiseGemmDlops_km_kn_mn_v3
{ {
template <typename ABuffer, template <typename ABuffer,

View File

@@ -15,7 +15,7 @@ namespace ck {
template <typename Data, template <typename Data,
typename Desc, typename Desc,
typename SliceLengths, typename SliceLengths,
typename std::enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false> typename enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceSet_v1 struct ThreadwiseTensorSliceSet_v1
{ {
static constexpr index_t nDim = SliceLengths::Size(); static constexpr index_t nDim = SliceLengths::Size();

View File

@@ -57,7 +57,7 @@ template <typename SrcData,
InMemoryDataOperationEnum_t DstInMemOp, InMemoryDataOperationEnum_t DstInMemOp,
index_t DstScalarStrideInVector, index_t DstScalarStrideInVector,
bool DstResetCoordinateAfterRun, bool DstResetCoordinateAfterRun,
typename std::enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false> typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v1r3 struct ThreadwiseTensorSliceTransfer_v1r3
{ {
static constexpr index_t nDim = SliceLengths::Size(); static constexpr index_t nDim = SliceLengths::Size();
@@ -373,7 +373,7 @@ template <typename SrcData,
index_t SrcScalarPerVector, index_t SrcScalarPerVector,
index_t SrcScalarStrideInVector, index_t SrcScalarStrideInVector,
bool SrcResetCoordinateAfterRun, bool SrcResetCoordinateAfterRun,
typename std::enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false> typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v2 struct ThreadwiseTensorSliceTransfer_v2
{ {
static constexpr index_t nDim = SliceLengths::Size(); static constexpr index_t nDim = SliceLengths::Size();
@@ -1261,18 +1261,17 @@ struct ThreadwiseTensorSliceTransfer_v3
// 3. DstOriginIdx is known at compile-time // 3. DstOriginIdx is known at compile-time
// 4. use direct address calculation // 4. use direct address calculation
// 3. vector access on src // 3. vector access on src
template < template <typename SrcData,
typename SrcData, typename DstData,
typename DstData, typename SrcDesc,
typename SrcDesc, typename DstDesc,
typename DstDesc, typename SliceLengths,
typename SliceLengths, typename DimAccessOrder,
typename DimAccessOrder, index_t SrcVectorDim,
index_t SrcVectorDim, index_t SrcScalarPerVector,
index_t SrcScalarPerVector, index_t SrcScalarStrideInVector,
index_t SrcScalarStrideInVector, typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
typename std::enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), bool>::type = false>
bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v4 struct ThreadwiseTensorSliceTransfer_v4
{ {
static constexpr index_t nDim = SliceLengths::Size(); static constexpr index_t nDim = SliceLengths::Size();

View File

@@ -621,17 +621,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1
// 3. DstOriginIdx is known at compile-time // 3. DstOriginIdx is known at compile-time
// 4. use direct address calculation // 4. use direct address calculation
// 3. vector access on src // 3. vector access on src
template < template <typename SrcData,
typename SrcData, typename DstData,
typename DstData, typename SrcDesc,
typename SrcDesc, typename DstDesc,
typename DstDesc, typename SliceLengths,
typename SliceLengths, typename DimAccessOrder,
typename DimAccessOrder, typename SrcVectorTensorLengths,
typename SrcVectorTensorLengths, typename SrcVectorTensorContiguousDimOrder,
typename SrcVectorTensorContiguousDimOrder, typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
typename std::enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(), bool>::type = false>
bool>::type = false>
struct ThreadwiseTensorSliceTransfer_v4r1 struct ThreadwiseTensorSliceTransfer_v4r1
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};

View File

@@ -2,12 +2,13 @@
#define CK_C_STYLE_POINTER_CAST_HPP #define CK_C_STYLE_POINTER_CAST_HPP
#include "type.hpp" #include "type.hpp"
#include "enable_if.hpp"
namespace ck { namespace ck {
template <typename PY, template <typename PY,
typename PX, typename PX,
typename std::enable_if<is_pointer_v<PY> && is_pointer_v<PX>, bool>::type = false> typename enable_if<is_pointer_v<PY> && is_pointer_v<PX>, bool>::type = false>
__host__ __device__ PY c_style_pointer_cast(PX p_x) __host__ __device__ PY c_style_pointer_cast(PX p_x)
{ {
#pragma clang diagnostic push #pragma clang diagnostic push

View File

@@ -14,6 +14,7 @@
#include "functional2.hpp" #include "functional2.hpp"
#include "functional3.hpp" #include "functional3.hpp"
#include "functional4.hpp" #include "functional4.hpp"
#include "enable_if.hpp"
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "math.hpp" #include "math.hpp"
#include "number.hpp" #include "number.hpp"

View File

@@ -3,6 +3,7 @@
#include "amd_buffer_addressing.hpp" #include "amd_buffer_addressing.hpp"
#include "c_style_pointer_cast.hpp" #include "c_style_pointer_cast.hpp"
#include "enable_if.hpp"
namespace ck { namespace ck {
@@ -38,7 +39,7 @@ struct DynamicBuffer
} }
template <typename X, template <typename X,
typename std::enable_if< typename enable_if<
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type, is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value, typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
bool>::type = false> bool>::type = false>
@@ -93,7 +94,7 @@ struct DynamicBuffer
} }
template <typename X, template <typename X,
typename std::enable_if< typename enable_if<
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type, is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value, typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
bool>::type = false> bool>::type = false>

View File

@@ -0,0 +1,13 @@
#ifndef CK_ENABLE_IF_HPP
#define CK_ENABLE_IF_HPP
namespace ck {
template <bool B, typename T = void>
using enable_if = std::enable_if<B, T>;
template <bool B, typename T = void>
using enable_if_t = typename std::enable_if<B, T>::type;
} // namespace ck
#endif

View File

@@ -5,6 +5,7 @@
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "number.hpp" #include "number.hpp"
#include "type.hpp" #include "type.hpp"
#include "enable_if.hpp"
namespace ck { namespace ck {
namespace math { namespace math {
@@ -184,9 +185,7 @@ __host__ __device__ constexpr auto gcd(Number<X>, Number<Y>)
return Number<r>{}; return Number<r>{};
} }
template <typename X, template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
typename... Ys,
typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr auto gcd(X x, Ys... ys) __host__ __device__ constexpr auto gcd(X x, Ys... ys)
{ {
return gcd(x, gcd(ys...)); return gcd(x, gcd(ys...));
@@ -199,9 +198,7 @@ __host__ __device__ constexpr auto lcm(X x, Y y)
return (x * y) / gcd(x, y); return (x * y) / gcd(x, y);
} }
template <typename X, template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
typename... Ys,
typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr auto lcm(X x, Ys... ys) __host__ __device__ constexpr auto lcm(X x, Ys... ys)
{ {
return lcm(x, lcm(ys...)); return lcm(x, lcm(ys...));

View File

@@ -4,6 +4,7 @@
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "sequence.hpp" #include "sequence.hpp"
#include "type.hpp" #include "type.hpp"
#include "enable_if.hpp"
namespace ck { namespace ck {
@@ -20,10 +21,9 @@ struct TupleElement
{ {
__host__ __device__ constexpr TupleElement() = default; __host__ __device__ constexpr TupleElement() = default;
template < template <typename T,
typename T, typename enable_if<!is_same<remove_reference_t<remove_cv_t<T>>, TupleElement>::value,
typename std::enable_if<!is_same<remove_reference_t<remove_cv_t<T>>, TupleElement>::value, bool>::type = false>
bool>::type = false>
__host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward<T>(v)) __host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward<T>(v))
{ {
} }
@@ -58,17 +58,16 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
{ {
__host__ __device__ constexpr TupleImpl() = default; __host__ __device__ constexpr TupleImpl() = default;
template < template <typename Y,
typename Y, typename enable_if<sizeof...(Is) == 1 && sizeof...(Xs) == 1 &&
typename std::enable_if<sizeof...(Is) == 1 && sizeof...(Xs) == 1 && !is_same<remove_reference_t<remove_cv_t<Y>>, TupleImpl>::value,
!is_same<remove_reference_t<remove_cv_t<Y>>, TupleImpl>::value, bool>::type = false>
bool>::type = false>
__host__ __device__ constexpr TupleImpl(Y&& y) __host__ __device__ constexpr TupleImpl(Y&& y)
: TupleElement<TupleElementKey<Is>, Xs>(std::forward<Y>(y))... : TupleElement<TupleElementKey<Is>, Xs>(std::forward<Y>(y))...
{ {
} }
template <typename... Ys, typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false> template <typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
__host__ __device__ constexpr TupleImpl(Ys&&... ys) __host__ __device__ constexpr TupleImpl(Ys&&... ys)
: TupleElement<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))... : TupleElement<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
{ {
@@ -102,16 +101,16 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
__host__ __device__ constexpr Tuple() = default; __host__ __device__ constexpr Tuple() = default;
template <typename Y, template <typename Y,
typename std::enable_if< typename enable_if<sizeof...(Xs) == 1 &&
sizeof...(Xs) == 1 && !is_same<remove_reference_t<remove_cv_t<Y>>, Tuple>::value, !is_same<remove_reference_t<remove_cv_t<Y>>, Tuple>::value,
bool>::type = false> bool>::type = false>
__host__ __device__ constexpr Tuple(Y&& y) : base(std::forward<Y>(y)) __host__ __device__ constexpr Tuple(Y&& y) : base(std::forward<Y>(y))
{ {
} }
template <typename... Ys, template <typename... Ys,
typename std::enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2, typename enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2, bool>::type =
bool>::type = false> false>
__host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward<Ys>(ys)...) __host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward<Ys>(ys)...)
{ {
} }

View File

@@ -2,6 +2,7 @@
#define CK_TYPE_HPP #define CK_TYPE_HPP
#include "integral_constant.hpp" #include "integral_constant.hpp"
#include "enable_if.hpp"
namespace ck { namespace ck {
@@ -39,9 +40,7 @@ struct is_known_at_compile_time<integral_constant<T, X>>
static constexpr bool value = true; static constexpr bool value = true;
}; };
template <typename Y, template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
typename X,
typename std::enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
__host__ __device__ constexpr Y as_type(X x) __host__ __device__ constexpr Y as_type(X x)
{ {
union AsType union AsType