From 4fd03dee05caa4be4912f0e1cd335568c72b46db Mon Sep 17 00:00:00 2001 From: danthe3rd Date: Thu, 13 Oct 2022 07:36:21 +0000 Subject: [PATCH] Remove std:: references --- .../attention_scaling_coefs_updater.h | 31 +++++---- .../find_default_mma.h | 2 +- .../gemm/custom_mma.h | 7 +- .../gemm/custom_mma_multistage.h | 2 +- .../gemm_kernel_utils.h | 65 +++++++++++-------- .../kernel_forward.h | 29 +++++---- .../mma_from_smem.h | 15 ++--- include/cutlass/platform/platform.h | 23 +++++++ 8 files changed, 102 insertions(+), 72 deletions(-) diff --git a/examples/42_fused_multi_head_attention/attention_scaling_coefs_updater.h b/examples/42_fused_multi_head_attention/attention_scaling_coefs_updater.h index d74b5a989..c8e0df3b7 100644 --- a/examples/42_fused_multi_head_attention/attention_scaling_coefs_updater.h +++ b/examples/42_fused_multi_head_attention/attention_scaling_coefs_updater.h @@ -7,7 +7,6 @@ #include "cutlass/matrix_shape.h" #include "gemm_kernel_utils.h" - namespace { static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) { @@ -69,9 +68,7 @@ struct RegisterOps { BASE::iterateRows( lane_offset, [&](int accum_m) { - // XXX: This is a __host__ function, so we need to workaround - // max = -std::numeric_limits::infinity(); - max = gemm_kernel_utils::TypeTraits::kMinusInf; + max = -cutlass::platform::numeric_limits::infinity(); }, [&](int accum_m, int accum_n, int idx) { if (kFullColumns || accum_n < max_col) { @@ -140,7 +137,9 @@ struct AttentionScalingCoefsUpdaterSm80 accum_t, kWarpSize> { static_assert( - std::is_same::value); + cutlass::platform:: + is_same::value, + "only RowMajor is supported"); using Policy = typename T::Policy; using InstructionShape = typename T::InstructionShape; @@ -208,10 +207,6 @@ struct AttentionScalingCoefsUpdaterSm80 } }; -// cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator, float, cutlass::layout::RowMajor, cutlass::gemm::GemmShape<16, 16, 4>, -// cutlass::MatrixShape<1, 1>> See -// cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h template struct AttentionScalingCoefsUpdaterVolta : RegisterOps< @@ -220,7 +215,9 @@ struct AttentionScalingCoefsUpdaterVolta accum_t, kWarpSize> { static_assert( - std::is_same::value); + cutlass::platform:: + is_same::value, + "only RowMajor is supported"); using Policy = typename T::Policy; using InstructionShape = typename T::InstructionShape; @@ -341,9 +338,9 @@ struct AttentionScalingCoefsUpdaterSimt using Delta = typename T::Delta; using Shape = typename T::Shape; static_assert( - std::is_same::value); - static_assert( - std::is_same::value); + cutlass::platform:: + is_same::value, + "only RowMajor is supported"); template CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { @@ -392,9 +389,11 @@ struct AttentionScalingCoefsUpdaterSimt int8_t lane_id, int8_t warp_id, typename T::TensorCoord const& tile_offset) { - static_assert(std::is_same< - typename Policy::LaneLayout, - cutlass::layout::RowMajorInterleaved<1>>::value); + static_assert( + cutlass::platform::is_same< + typename Policy::LaneLayout, + cutlass::layout::RowMajorInterleaved<1>>::value, + ""); typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); cutlass::MatrixCoord lane_offset = lane_layout.inverse(lane_id) * diff --git a/examples/42_fused_multi_head_attention/find_default_mma.h b/examples/42_fused_multi_head_attention/find_default_mma.h index a9f4aae92..9cd64d6da 100644 --- a/examples/42_fused_multi_head_attention/find_default_mma.h +++ b/examples/42_fused_multi_head_attention/find_default_mma.h @@ -114,7 +114,7 @@ struct FindDefaultMma< InstructionShape, kStages, Operator, - typename std::enable_if<(kAlignmentA > 1)>::type> { + typename cutlass::platform::enable_if<(kAlignmentA > 1)>::type> { using LayoutC = layout::RowMajor; using OperatorClass = arch::OpClassTensorOp; using ArchTag = arch::Sm80; diff --git a/examples/42_fused_multi_head_attention/gemm/custom_mma.h b/examples/42_fused_multi_head_attention/gemm/custom_mma.h index 772b6ac5a..c0f1cd500 100644 --- a/examples/42_fused_multi_head_attention/gemm/custom_mma.h +++ b/examples/42_fused_multi_head_attention/gemm/custom_mma.h @@ -38,9 +38,12 @@ struct MakeCustomMma< SharedMemoryClear>, kMaxK> { // Reduce the number of stages if we don't need that many - static int constexpr kStages = kMaxK == std::numeric_limits::max() + static int constexpr kStages = + kMaxK == cutlass::platform::numeric_limits::max() ? Stages - : std::min(Stages, (kMaxK + int(Shape::kK) - 1) / int(Shape::kK)); + : cutlass::const_min( + Stages, + (kMaxK + int(Shape::kK) - 1) / int(Shape::kK)); using Mma = cutlass::gemm::threadblock::CustomMmaMultistage< Shape, IteratorA, diff --git a/examples/42_fused_multi_head_attention/gemm/custom_mma_multistage.h b/examples/42_fused_multi_head_attention/gemm/custom_mma_multistage.h index 6d6ce7b9d..fefee4308 100644 --- a/examples/42_fused_multi_head_attention/gemm/custom_mma_multistage.h +++ b/examples/42_fused_multi_head_attention/gemm/custom_mma_multistage.h @@ -88,7 +88,7 @@ template < /// Use zfill or predicate for out-of-bound cp.async SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, /// Upper boundon the K dimension - int kMaxK = std::numeric_limits::max(), + int kMaxK = cutlass::platform::numeric_limits::max(), /// Used for partial specialization typename Enable = bool> class CustomMmaMultistage : public CustomMmaBase { diff --git a/examples/42_fused_multi_head_attention/gemm_kernel_utils.h b/examples/42_fused_multi_head_attention/gemm_kernel_utils.h index 2bd515bf5..bba9e6fcc 100644 --- a/examples/42_fused_multi_head_attention/gemm_kernel_utils.h +++ b/examples/42_fused_multi_head_attention/gemm_kernel_utils.h @@ -67,26 +67,39 @@ #define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ TORCH_CHECK(uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned") #define XFORMERS_CHECK TORCH_CHECK -#else -#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ - if(!(uint64_t(PTR) % ALIGNMENT == 0)) {\ - std::cerr << #PTR " is not correctly aligned\n";\ - return false;\ +#elif defined(__CUDACC_RTC__) +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \ + return false; \ } -#define XFORMERS_CHECK(COND, ERR) if (!(COND)) {\ - std::cerr << #COND " failed\n";\ - return false;\ +#define XFORMERS_CHECK(COND, ERR) \ + if (!(COND)) { \ + return false; \ + } +#else +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \ + std::cerr << #PTR " is not correctly aligned\n"; \ + return false; \ + } +#define XFORMERS_CHECK(COND, ERR) \ + if (!(COND)) { \ + std::cerr << #COND " failed\n"; \ + return false; \ } #endif -#define ASSIGN_CHECK_OVERFLOW(A, B) \ - { \ - A = B; \ - TORCH_CHECK(B < std::numeric_limits::max(), #B " overflows"); \ +#define ASSIGN_CHECK_OVERFLOW(A, B) \ + { \ + A = B; \ + TORCH_CHECK( \ + B < cutlass::platform::numeric_limits::max(), \ + #B " overflows"); \ } namespace gemm_kernel_utils { +#ifdef HAS_PYTORCH template struct TypeTraits; @@ -94,7 +107,6 @@ template <> struct TypeTraits { using scalar_t = cutlass::half_t; -#ifdef HAS_PYTORCH static constexpr __host__ at::ScalarType atScalarType() { return at::ScalarType::Half; } @@ -106,14 +118,12 @@ struct TypeTraits { tensor.sizes().data(), tensor.strides().data()); } -#endif }; template <> struct TypeTraits { using scalar_t = cutlass::bfloat16_t; -#ifdef HAS_PYTORCH static constexpr __host__ at::ScalarType atScalarType() { return at::ScalarType::BFloat16; } @@ -125,16 +135,12 @@ struct TypeTraits { tensor.sizes().data(), tensor.strides().data()); } -#endif }; template <> struct TypeTraits { using scalar_t = float; - static constexpr scalar_t kInf = std::numeric_limits::infinity(); - static constexpr scalar_t kMinusInf = -std::numeric_limits::infinity(); -#ifdef HAS_PYTORCH static constexpr __host__ at::ScalarType atScalarType() { return at::ScalarType::Float; } @@ -143,8 +149,8 @@ struct TypeTraits { at::Tensor const& tensor) { return tensor.packed_accessor32(); } -#endif }; +#endif template constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) { @@ -172,7 +178,8 @@ template struct DefaultGemmType< ArchTag, float, - typename std::enable_if= 80>::type> { + typename cutlass::platform::enable_if< + ArchTag::kMinComputeCapability >= 80>::type> { static constexpr int ThreadK = 32; static constexpr int WarpK = 32; static constexpr int kMinimumAlignment = 4; @@ -186,7 +193,7 @@ template struct DefaultGemmType< ArchTag, scalar_t, - typename std::enable_if< + typename cutlass::platform::enable_if< ArchTag::kMinComputeCapability >= 75 && cutlass::sizeof_bits::value == 16>::type> { static constexpr int ThreadK = 32; @@ -216,17 +223,19 @@ struct call_conditional; template struct call_conditional { - template - static CUTLASS_DEVICE auto apply(TA ta, TB tb, Args&&... args) -> decltype(ta(std::forward(args)...)) { - return ta(std::forward(args)...); + template + static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) + -> decltype(ta(arg)) { + return ta(arg); } }; template struct call_conditional { - template - static CUTLASS_DEVICE auto apply(TA ta, TB tb, Args&&... args) -> decltype(tb(std::forward(args)...)) { - return tb(std::forward(args)...); + template + static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) + -> decltype(tb(arg)) { + return tb(arg); } }; diff --git a/examples/42_fused_multi_head_attention/kernel_forward.h b/examples/42_fused_multi_head_attention/kernel_forward.h index 2b07e600a..fe820d4b8 100644 --- a/examples/42_fused_multi_head_attention/kernel_forward.h +++ b/examples/42_fused_multi_head_attention/kernel_forward.h @@ -44,7 +44,8 @@ namespace { template constexpr int getWarpsPerSm() { return ( - Arch::kMinComputeCapability >= 80 && !std::is_same::value + Arch::kMinComputeCapability >= 80 && + !cutlass::platform::is_same::value ? 16 : 12); } @@ -75,8 +76,8 @@ struct AttentionKernel { static constexpr bool kPreloadV = ArchTag::kMinComputeCapability >= 80 && cutlass::sizeof_bits::value == 16; static constexpr bool kKeepOutputInRF = kSingleValueIteration; - static constexpr bool kNeedsOutputAccumulatorBuffer = - !kKeepOutputInRF && !std::is_same::value; + static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF && + !cutlass::platform::is_same::value; static_assert(kQueriesPerBlock % 32 == 0, ""); static_assert(kKeysPerBlock % 32 == 0, ""); @@ -413,7 +414,7 @@ struct AttentionKernel { } }; - using SharedStorage = typename std::conditional< + using SharedStorage = typename cutlass::platform::conditional< kSingleValueIteration || kKeepOutputInRF, SharedStorageEpilogueAtEnd, SharedStorageEpilogueInLoop>::type; @@ -446,8 +447,9 @@ struct AttentionKernel { static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); if (thread_id() < kQueriesPerBlock) { s_prime[thread_id()] = accum_t(0); - m_prime[thread_id()] = gemm_kernel_utils::TypeTraits::kMinusInf; - mi[thread_id()] = gemm_kernel_utils::TypeTraits::kMinusInf; + m_prime[thread_id()] = + -cutlass::platform::numeric_limits::infinity(); + mi[thread_id()] = -cutlass::platform::numeric_limits::infinity(); } typename MM1::Mma::FragmentC accum_o; accum_o.clear(); @@ -580,7 +582,8 @@ struct AttentionKernel { }, [&](int accum_m, int accum_n, int idx) { if (accum_n > last_col) { - accum[idx] = gemm_kernel_utils::TypeTraits::kMinusInf; + accum[idx] = + -cutlass::platform::numeric_limits::infinity(); } }, [&](int accum_m) {}); @@ -609,7 +612,7 @@ struct AttentionKernel { warp_id(), p.num_keys - iter_key_start, iteratorC_tile_offset, - 1.0f / std::sqrt(float(p.head_dim))); + 1.0f / cutlass::fast_sqrt(float(p.head_dim))); })); })); @@ -683,7 +686,7 @@ struct AttentionKernel { using ElementCompute = typename DefaultOp::ElementCompute; using EpilogueOutputOp = typename cutlass::epilogue:: thread::MemoryEfficientAttentionNormalize< - typename std::conditional< + typename cutlass::platform::conditional< kIsLast, output_t, output_accum_t>::type, @@ -699,7 +702,7 @@ struct AttentionKernel { typename DefaultEpilogue::Shape, typename MM1::Mma::Operator, DefaultEpilogue::kPartitionsK, - typename std::conditional< + typename cutlass::platform::conditional< kIsLast, typename MM1::OutputTileIterator, typename MM1::OutputTileIteratorAccum>::type, @@ -787,11 +790,11 @@ struct AttentionKernel { if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) { auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE; if (thread_id() < p.num_queries) { - p.logsumexp_ptr[thread_id()] = - accum_t(mi[thread_id()]) + std::log(accum_t(s_prime[thread_id()])); + p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()]) + + cutlass::fast_log(accum_t(s_prime[thread_id()])); } else if (thread_id() < lse_dim) { p.logsumexp_ptr[thread_id()] = - gemm_kernel_utils::TypeTraits::kInf; + cutlass::platform::numeric_limits::infinity(); } } } diff --git a/examples/42_fused_multi_head_attention/mma_from_smem.h b/examples/42_fused_multi_head_attention/mma_from_smem.h index 9df003744..e610db3c7 100644 --- a/examples/42_fused_multi_head_attention/mma_from_smem.h +++ b/examples/42_fused_multi_head_attention/mma_from_smem.h @@ -1031,14 +1031,6 @@ class MmaMultistageFromSharedMemory : public MmaBaseFromSharedMemory< } }; -namespace { -template -struct AssertIsSame { - static_assert(std::is_same::value); - using CHECK = bool; -}; -} // namespace - template < typename WarpShape, typename InstructionShape, @@ -1264,7 +1256,8 @@ struct DefaultMmaFromSharedMemory< static int constexpr kMaxK = AccumulatorSharedStorage_::Shape::kN; // Reduce the number of stages if we don't need that many - static int constexpr kStagesMax = (kMaxK + int(Shape_::kK) - 1) / int(Shape_::kK); + static int constexpr kStagesMax = + (kMaxK + int(Shape_::kK) - 1) / int(Shape_::kK); static int constexpr kStages = cutlass::const_min(Stages, kStagesMax); using IteratorB = @@ -1630,7 +1623,7 @@ struct B2bGemm< if (rowIdx == 1) { lse_prefetched[colIdx] = accum_n < lse_extent ? lse[accum_n] - : std::numeric_limits::infinity(); + : platform::numeric_limits::infinity(); } accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]); ++colIdx; @@ -1770,7 +1763,7 @@ struct B2bGemm< if (rowIdx == 1) { lse_prefetched[colIdx] = accum_n < lse_extent ? lse[accum_n] - : std::numeric_limits::infinity(); + : platform::numeric_limits::infinity(); } accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]); ++colIdx; diff --git a/include/cutlass/platform/platform.h b/include/cutlass/platform/platform.h index 6b8a626f8..ced1bef21 100644 --- a/include/cutlass/platform/platform.h +++ b/include/cutlass/platform/platform.h @@ -574,6 +574,21 @@ using std::is_trivially_copyable; #endif + +//----------------------------------------------------------------------------- +// bit_cast +//----------------------------------------------------------------------------- + +template< class To, class From > +constexpr To CUTLASS_HOST_DEVICE bit_cast(const From& from ) noexcept; + +template +constexpr To CUTLASS_HOST_DEVICE bit_cast(const From& src) noexcept +{ + static_assert(sizeof(To) == sizeof(From), "sizes must match"); + return reinterpret_cast(src); +} + //----------------------------------------------------------------------------- // Alignment and layout utilities //----------------------------------------------------------------------------- @@ -865,5 +880,13 @@ struct numeric_limits { static constexpr bool is_integer = true; }; +template <> +struct numeric_limits { + CUTLASS_HOST_DEVICE + static constexpr float infinity() noexcept { return bit_cast(0x7f800000);} + static constexpr bool is_integer = false; + static constexpr bool has_infinity = true; +}; + } // namespace platform } // namespace cutlass