mirror of
https://github.com/ROCm/composable_kernel.git
synced 2026-05-11 17:00:18 +00:00
refactor
This commit is contained in:
@@ -43,6 +43,7 @@ message(STATUS "Build with HIP ${hip_VERSION}")
|
||||
message("HALF_INCLUDE_DIR: ${HALF_INCLUDE_DIR}")
|
||||
|
||||
# CMAKE_CXX_FLAGS
|
||||
SET(BUILD_DEV ON CACHE BOOL "BUILD_DEV")
|
||||
if(BUILD_DEV)
|
||||
string(APPEND CMAKE_CXX_FLAGS " -Werror -Weverything")
|
||||
endif()
|
||||
|
||||
@@ -377,7 +377,7 @@ struct RightPad
|
||||
// at compile-time
|
||||
template <typename UpLengths,
|
||||
typename Coefficients,
|
||||
typename std::enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
|
||||
typename enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
|
||||
struct Embed
|
||||
{
|
||||
static constexpr index_t NDimUp = UpLengths::Size();
|
||||
|
||||
@@ -42,7 +42,7 @@ __host__ __device__ constexpr auto make_right_pad_transform(
|
||||
|
||||
template <typename UpLengths,
|
||||
typename Coefficients,
|
||||
typename std::enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
|
||||
typename enable_if<UpLengths::Size() == Coefficients::Size(), bool>::type = false>
|
||||
__host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_lengths,
|
||||
const Coefficients& coefficients)
|
||||
{
|
||||
|
||||
@@ -454,9 +454,7 @@ __host__ __device__ constexpr auto make_single_stage_tensor_adaptor(const Transf
|
||||
remove_cv_t<decltype(top_dim_hidden_ids)>>{transforms};
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename... Xs,
|
||||
typename std::enable_if<sizeof...(Xs) >= 2, bool>::type = false>
|
||||
template <typename X, typename... Xs, typename enable_if<sizeof...(Xs) >= 2, bool>::type = false>
|
||||
__host__ __device__ constexpr auto chain_tensor_adaptors(const X& x, const Xs&... xs)
|
||||
{
|
||||
return chain_tensor_adaptors(x, chain_tensor_adaptors(xs...));
|
||||
|
||||
@@ -37,7 +37,7 @@ __host__ __device__ constexpr auto calculate_element_space_size_impl(const Lengt
|
||||
|
||||
template <typename... Lengths,
|
||||
typename... Strides,
|
||||
typename std::enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
|
||||
typename enable_if<sizeof...(Lengths) == sizeof...(Strides), bool>::type = false>
|
||||
__host__ __device__ constexpr auto make_naive_tensor_descriptor_v2(const Tuple<Lengths...>& lengths,
|
||||
const Tuple<Strides...>& strides)
|
||||
{
|
||||
|
||||
@@ -22,24 +22,24 @@ namespace ck {
|
||||
// 2. CThreadBuffer is StaticBuffer
|
||||
// Also assume:
|
||||
// M0 = N0 = 2. It will do 2x2 pipelined read and fma (ABBA optimization)
|
||||
template <index_t BlockSize,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename AKMBlockDesc,
|
||||
typename BKNBlockDesc,
|
||||
index_t M1PerThreadM11,
|
||||
index_t N1PerThreadN11,
|
||||
index_t KPerThread,
|
||||
index_t M1N1ThreadClusterM100,
|
||||
index_t M1N1ThreadClusterN100,
|
||||
index_t M1N1ThreadClusterM101,
|
||||
index_t M1N1ThreadClusterN101,
|
||||
index_t AThreadCopyScalarPerVector_M11,
|
||||
index_t BThreadCopyScalarPerVector_N11,
|
||||
typename std::enable_if<AKMBlockDesc::IsKnownAtCompileTime() &&
|
||||
BKNBlockDesc::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
template <
|
||||
index_t BlockSize,
|
||||
typename FloatA,
|
||||
typename FloatB,
|
||||
typename FloatC,
|
||||
typename AKMBlockDesc,
|
||||
typename BKNBlockDesc,
|
||||
index_t M1PerThreadM11,
|
||||
index_t N1PerThreadN11,
|
||||
index_t KPerThread,
|
||||
index_t M1N1ThreadClusterM100,
|
||||
index_t M1N1ThreadClusterN100,
|
||||
index_t M1N1ThreadClusterM101,
|
||||
index_t M1N1ThreadClusterN101,
|
||||
index_t AThreadCopyScalarPerVector_M11,
|
||||
index_t BThreadCopyScalarPerVector_N11,
|
||||
typename enable_if<AKMBlockDesc::IsKnownAtCompileTime() && BKNBlockDesc::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v2r2_pipeline_2x2
|
||||
{
|
||||
using AIndex = MultiIndex<3>;
|
||||
|
||||
@@ -38,9 +38,9 @@ template <index_t BlockSize,
|
||||
// BM10BN10ThreadClusterBN101, ...>
|
||||
index_t AThreadCopyScalarPerVector_BM11,
|
||||
index_t BThreadCopyScalarPerVector_BN11,
|
||||
typename std::enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
|
||||
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
typename enable_if<ABlockDesc_BK0_BM_BK1::IsKnownAtCompileTime() &&
|
||||
BBlockDesc_BK0_BN_BK1::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
struct BlockwiseGemmDlops_A_BK0_BM_BK1_B_BK0_BN_BK1_C_BM0_BM1_BN0_BN1_pipeline_BM0_2_BN0_2
|
||||
{
|
||||
using AIndex = MultiIndex<3>;
|
||||
|
||||
@@ -21,10 +21,10 @@ template <typename FloatA,
|
||||
typename TKLengths,
|
||||
typename TMLengths,
|
||||
typename TNLengths,
|
||||
typename std::enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
|
||||
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
|
||||
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
|
||||
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
|
||||
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
struct ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1
|
||||
{
|
||||
__device__ constexpr ThreadwiseGemmDlops_km0m1_kn0n1_m0m1n0n1()
|
||||
@@ -123,10 +123,10 @@ template <typename FloatA,
|
||||
typename TKLengths,
|
||||
typename TMLengths,
|
||||
typename TNLengths,
|
||||
typename std::enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
|
||||
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
|
||||
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
typename enable_if<AThreadDesc_TK0_TM0_TM1_TK1::IsKnownAtCompileTime() &&
|
||||
BThreadDesc_TK0_TN0_TN1_TK1::IsKnownAtCompileTime() &&
|
||||
CThreadDesc_TM0_TM1_TN0_TN1::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
struct ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1
|
||||
{
|
||||
__device__ constexpr ThreadwiseContractionDlops_A_TK0_TM0_TM1_TK1_B_TK0_TN0_TN1_TK1_C_TM0_TM1_TN0_TN1()
|
||||
|
||||
@@ -19,9 +19,9 @@ template <typename FloatA,
|
||||
typename CDesc,
|
||||
index_t H,
|
||||
index_t W,
|
||||
typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
|
||||
CDesc::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
typename enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
|
||||
CDesc::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
struct ThreadwiseGemmDlops_km_kn_mn_v3
|
||||
{
|
||||
template <typename ABuffer,
|
||||
|
||||
@@ -15,7 +15,7 @@ namespace ck {
|
||||
template <typename Data,
|
||||
typename Desc,
|
||||
typename SliceLengths,
|
||||
typename std::enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false>
|
||||
typename enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false>
|
||||
struct ThreadwiseTensorSliceSet_v1
|
||||
{
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
|
||||
@@ -57,7 +57,7 @@ template <typename SrcData,
|
||||
InMemoryDataOperationEnum_t DstInMemOp,
|
||||
index_t DstScalarStrideInVector,
|
||||
bool DstResetCoordinateAfterRun,
|
||||
typename std::enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
|
||||
typename enable_if<SrcDesc::IsKnownAtCompileTime(), bool>::type = false>
|
||||
struct ThreadwiseTensorSliceTransfer_v1r3
|
||||
{
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
@@ -373,7 +373,7 @@ template <typename SrcData,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t SrcScalarStrideInVector,
|
||||
bool SrcResetCoordinateAfterRun,
|
||||
typename std::enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
|
||||
typename enable_if<DstDesc::IsKnownAtCompileTime(), bool>::type = false>
|
||||
struct ThreadwiseTensorSliceTransfer_v2
|
||||
{
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
@@ -1261,18 +1261,17 @@ struct ThreadwiseTensorSliceTransfer_v3
|
||||
// 3. DstOriginIdx is known at compile-time
|
||||
// 4. use direct address calculation
|
||||
// 3. vector access on src
|
||||
template <
|
||||
typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename SliceLengths,
|
||||
typename DimAccessOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t SrcScalarStrideInVector,
|
||||
typename std::enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename SliceLengths,
|
||||
typename DimAccessOrder,
|
||||
index_t SrcVectorDim,
|
||||
index_t SrcScalarPerVector,
|
||||
index_t SrcScalarStrideInVector,
|
||||
typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
struct ThreadwiseTensorSliceTransfer_v4
|
||||
{
|
||||
static constexpr index_t nDim = SliceLengths::Size();
|
||||
|
||||
@@ -621,17 +621,16 @@ struct ThreadwiseTensorSliceTransfer_v3r1
|
||||
// 3. DstOriginIdx is known at compile-time
|
||||
// 4. use direct address calculation
|
||||
// 3. vector access on src
|
||||
template <
|
||||
typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename SliceLengths,
|
||||
typename DimAccessOrder,
|
||||
typename SrcVectorTensorLengths,
|
||||
typename SrcVectorTensorContiguousDimOrder,
|
||||
typename std::enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
template <typename SrcData,
|
||||
typename DstData,
|
||||
typename SrcDesc,
|
||||
typename DstDesc,
|
||||
typename SliceLengths,
|
||||
typename DimAccessOrder,
|
||||
typename SrcVectorTensorLengths,
|
||||
typename SrcVectorTensorContiguousDimOrder,
|
||||
typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
|
||||
bool>::type = false>
|
||||
struct ThreadwiseTensorSliceTransfer_v4r1
|
||||
{
|
||||
static constexpr auto I0 = Number<0>{};
|
||||
|
||||
@@ -2,12 +2,13 @@
|
||||
#define CK_C_STYLE_POINTER_CAST_HPP
|
||||
|
||||
#include "type.hpp"
|
||||
#include "enable_if.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <typename PY,
|
||||
typename PX,
|
||||
typename std::enable_if<is_pointer_v<PY> && is_pointer_v<PX>, bool>::type = false>
|
||||
typename enable_if<is_pointer_v<PY> && is_pointer_v<PX>, bool>::type = false>
|
||||
__host__ __device__ PY c_style_pointer_cast(PX p_x)
|
||||
{
|
||||
#pragma clang diagnostic push
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
#include "functional2.hpp"
|
||||
#include "functional3.hpp"
|
||||
#include "functional4.hpp"
|
||||
#include "enable_if.hpp"
|
||||
#include "integral_constant.hpp"
|
||||
#include "math.hpp"
|
||||
#include "number.hpp"
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
#include "amd_buffer_addressing.hpp"
|
||||
#include "c_style_pointer_cast.hpp"
|
||||
#include "enable_if.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -38,7 +39,7 @@ struct DynamicBuffer
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename std::enable_if<
|
||||
typename enable_if<
|
||||
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
|
||||
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
|
||||
bool>::type = false>
|
||||
@@ -93,7 +94,7 @@ struct DynamicBuffer
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename std::enable_if<
|
||||
typename enable_if<
|
||||
is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
|
||||
typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
|
||||
bool>::type = false>
|
||||
|
||||
13
composable_kernel/include/utility/enable_if.hpp
Normal file
13
composable_kernel/include/utility/enable_if.hpp
Normal file
@@ -0,0 +1,13 @@
|
||||
#ifndef CK_ENABLE_IF_HPP
|
||||
#define CK_ENABLE_IF_HPP
|
||||
|
||||
namespace ck {
|
||||
|
||||
template <bool B, typename T = void>
|
||||
using enable_if = std::enable_if<B, T>;
|
||||
|
||||
template <bool B, typename T = void>
|
||||
using enable_if_t = typename std::enable_if<B, T>::type;
|
||||
|
||||
} // namespace ck
|
||||
#endif
|
||||
@@ -5,6 +5,7 @@
|
||||
#include "integral_constant.hpp"
|
||||
#include "number.hpp"
|
||||
#include "type.hpp"
|
||||
#include "enable_if.hpp"
|
||||
|
||||
namespace ck {
|
||||
namespace math {
|
||||
@@ -184,9 +185,7 @@ __host__ __device__ constexpr auto gcd(Number<X>, Number<Y>)
|
||||
return Number<r>{};
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename... Ys,
|
||||
typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
|
||||
template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
|
||||
__host__ __device__ constexpr auto gcd(X x, Ys... ys)
|
||||
{
|
||||
return gcd(x, gcd(ys...));
|
||||
@@ -199,9 +198,7 @@ __host__ __device__ constexpr auto lcm(X x, Y y)
|
||||
return (x * y) / gcd(x, y);
|
||||
}
|
||||
|
||||
template <typename X,
|
||||
typename... Ys,
|
||||
typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
|
||||
template <typename X, typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
|
||||
__host__ __device__ constexpr auto lcm(X x, Ys... ys)
|
||||
{
|
||||
return lcm(x, lcm(ys...));
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
#include "integral_constant.hpp"
|
||||
#include "sequence.hpp"
|
||||
#include "type.hpp"
|
||||
#include "enable_if.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -20,10 +21,9 @@ struct TupleElement
|
||||
{
|
||||
__host__ __device__ constexpr TupleElement() = default;
|
||||
|
||||
template <
|
||||
typename T,
|
||||
typename std::enable_if<!is_same<remove_reference_t<remove_cv_t<T>>, TupleElement>::value,
|
||||
bool>::type = false>
|
||||
template <typename T,
|
||||
typename enable_if<!is_same<remove_reference_t<remove_cv_t<T>>, TupleElement>::value,
|
||||
bool>::type = false>
|
||||
__host__ __device__ constexpr TupleElement(T&& v) : mData(std::forward<T>(v))
|
||||
{
|
||||
}
|
||||
@@ -58,17 +58,16 @@ struct TupleImpl<Sequence<Is...>, Xs...> : TupleElement<TupleElementKey<Is>, Xs>
|
||||
{
|
||||
__host__ __device__ constexpr TupleImpl() = default;
|
||||
|
||||
template <
|
||||
typename Y,
|
||||
typename std::enable_if<sizeof...(Is) == 1 && sizeof...(Xs) == 1 &&
|
||||
!is_same<remove_reference_t<remove_cv_t<Y>>, TupleImpl>::value,
|
||||
bool>::type = false>
|
||||
template <typename Y,
|
||||
typename enable_if<sizeof...(Is) == 1 && sizeof...(Xs) == 1 &&
|
||||
!is_same<remove_reference_t<remove_cv_t<Y>>, TupleImpl>::value,
|
||||
bool>::type = false>
|
||||
__host__ __device__ constexpr TupleImpl(Y&& y)
|
||||
: TupleElement<TupleElementKey<Is>, Xs>(std::forward<Y>(y))...
|
||||
{
|
||||
}
|
||||
|
||||
template <typename... Ys, typename std::enable_if<sizeof...(Ys) >= 2, bool>::type = false>
|
||||
template <typename... Ys, typename enable_if<sizeof...(Ys) >= 2, bool>::type = false>
|
||||
__host__ __device__ constexpr TupleImpl(Ys&&... ys)
|
||||
: TupleElement<TupleElementKey<Is>, Xs>(std::forward<Ys>(ys))...
|
||||
{
|
||||
@@ -102,16 +101,16 @@ struct Tuple : detail::TupleImpl<typename arithmetic_sequence_gen<0, sizeof...(X
|
||||
__host__ __device__ constexpr Tuple() = default;
|
||||
|
||||
template <typename Y,
|
||||
typename std::enable_if<
|
||||
sizeof...(Xs) == 1 && !is_same<remove_reference_t<remove_cv_t<Y>>, Tuple>::value,
|
||||
bool>::type = false>
|
||||
typename enable_if<sizeof...(Xs) == 1 &&
|
||||
!is_same<remove_reference_t<remove_cv_t<Y>>, Tuple>::value,
|
||||
bool>::type = false>
|
||||
__host__ __device__ constexpr Tuple(Y&& y) : base(std::forward<Y>(y))
|
||||
{
|
||||
}
|
||||
|
||||
template <typename... Ys,
|
||||
typename std::enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2,
|
||||
bool>::type = false>
|
||||
typename enable_if<sizeof...(Ys) == sizeof...(Xs) && sizeof...(Ys) >= 2, bool>::type =
|
||||
false>
|
||||
__host__ __device__ constexpr Tuple(Ys&&... ys) : base(std::forward<Ys>(ys)...)
|
||||
{
|
||||
}
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
#define CK_TYPE_HPP
|
||||
|
||||
#include "integral_constant.hpp"
|
||||
#include "enable_if.hpp"
|
||||
|
||||
namespace ck {
|
||||
|
||||
@@ -39,9 +40,7 @@ struct is_known_at_compile_time<integral_constant<T, X>>
|
||||
static constexpr bool value = true;
|
||||
};
|
||||
|
||||
template <typename Y,
|
||||
typename X,
|
||||
typename std::enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
|
||||
template <typename Y, typename X, typename enable_if<sizeof(X) == sizeof(Y), bool>::type = false>
|
||||
__host__ __device__ constexpr Y as_type(X x)
|
||||
{
|
||||
union AsType
|
||||
|
||||
Reference in New Issue
Block a user