Remove std:: references

This commit is contained in:
danthe3rd
2022-10-13 07:36:21 +00:00
parent 887f1df16b
commit 4fd03dee05
8 changed files with 102 additions and 72 deletions

View File

@@ -7,7 +7,6 @@
#include "cutlass/matrix_shape.h"
#include "gemm_kernel_utils.h"
namespace {
static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) {
@@ -69,9 +68,7 @@ struct RegisterOps {
BASE::iterateRows(
lane_offset,
[&](int accum_m) {
// XXX: This is a __host__ function, so we need to workaround
// max = -std::numeric_limits<accum_t>::infinity();
max = gemm_kernel_utils::TypeTraits<accum_t>::kMinusInf;
max = -cutlass::platform::numeric_limits<accum_t>::infinity();
},
[&](int accum_m, int accum_n, int idx) {
if (kFullColumns || accum_n < max_col) {
@@ -140,7 +137,9 @@ struct AttentionScalingCoefsUpdaterSm80
accum_t,
kWarpSize> {
static_assert(
std::is_same<typename T::Layout, cutlass::layout::RowMajor>::value);
cutlass::platform::
is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
"only RowMajor is supported");
using Policy = typename T::Policy;
using InstructionShape = typename T::InstructionShape;
@@ -208,10 +207,6 @@ struct AttentionScalingCoefsUpdaterSm80
}
};
// cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator<cutlass::MatrixShape<32,
// 32>, float, cutlass::layout::RowMajor, cutlass::gemm::GemmShape<16, 16, 4>,
// cutlass::MatrixShape<1, 1>> See
// cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h
template <typename T, typename accum_t, int kWarpSize>
struct AttentionScalingCoefsUpdaterVolta
: RegisterOps<
@@ -220,7 +215,9 @@ struct AttentionScalingCoefsUpdaterVolta
accum_t,
kWarpSize> {
static_assert(
std::is_same<typename T::Layout, cutlass::layout::RowMajor>::value);
cutlass::platform::
is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
"only RowMajor is supported");
using Policy = typename T::Policy;
using InstructionShape = typename T::InstructionShape;
@@ -341,9 +338,9 @@ struct AttentionScalingCoefsUpdaterSimt
using Delta = typename T::Delta;
using Shape = typename T::Shape;
static_assert(
std::is_same<typename T::Layout, cutlass::layout::RowMajor>::value);
static_assert(
std::is_same<typename T::Iterations, typename T::Iterations>::value);
cutlass::platform::
is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
"only RowMajor is supported");
template <typename DT, typename F>
CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) {
@@ -392,9 +389,11 @@ struct AttentionScalingCoefsUpdaterSimt
int8_t lane_id,
int8_t warp_id,
typename T::TensorCoord const& tile_offset) {
static_assert(std::is_same<
typename Policy::LaneLayout,
cutlass::layout::RowMajorInterleaved<1>>::value);
static_assert(
cutlass::platform::is_same<
typename Policy::LaneLayout,
cutlass::layout::RowMajorInterleaved<1>>::value,
"");
typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
cutlass::MatrixCoord lane_offset = lane_layout.inverse(lane_id) *

View File

@@ -114,7 +114,7 @@ struct FindDefaultMma<
InstructionShape,
kStages,
Operator,
typename std::enable_if<(kAlignmentA > 1)>::type> {
typename cutlass::platform::enable_if<(kAlignmentA > 1)>::type> {
using LayoutC = layout::RowMajor;
using OperatorClass = arch::OpClassTensorOp;
using ArchTag = arch::Sm80;

View File

@@ -38,9 +38,12 @@ struct MakeCustomMma<
SharedMemoryClear>,
kMaxK> {
// Reduce the number of stages if we don't need that many
static int constexpr kStages = kMaxK == std::numeric_limits<int>::max()
static int constexpr kStages =
kMaxK == cutlass::platform::numeric_limits<int>::max()
? Stages
: std::min(Stages, (kMaxK + int(Shape::kK) - 1) / int(Shape::kK));
: cutlass::const_min(
Stages,
(kMaxK + int(Shape::kK) - 1) / int(Shape::kK));
using Mma = cutlass::gemm::threadblock::CustomMmaMultistage<
Shape,
IteratorA,

View File

@@ -88,7 +88,7 @@ template <
/// Use zfill or predicate for out-of-bound cp.async
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
/// Upper bound on the K dimension
int kMaxK = std::numeric_limits<int>::max(),
int kMaxK = cutlass::platform::numeric_limits<int>::max(),
/// Used for partial specialization
typename Enable = bool>
class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {

View File

@@ -67,26 +67,39 @@
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
TORCH_CHECK(uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned")
#define XFORMERS_CHECK TORCH_CHECK
#else
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
if(!(uint64_t(PTR) % ALIGNMENT == 0)) {\
std::cerr << #PTR " is not correctly aligned\n";\
return false;\
#elif defined(__CUDACC_RTC__)
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \
return false; \
}
#define XFORMERS_CHECK(COND, ERR) if (!(COND)) {\
std::cerr << #COND " failed\n";\
return false;\
#define XFORMERS_CHECK(COND, ERR) \
if (!(COND)) { \
return false; \
}
#else
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \
std::cerr << #PTR " is not correctly aligned\n"; \
return false; \
}
#define XFORMERS_CHECK(COND, ERR) \
if (!(COND)) { \
std::cerr << #COND " failed\n"; \
return false; \
}
#endif
#define ASSIGN_CHECK_OVERFLOW(A, B) \
{ \
A = B; \
TORCH_CHECK(B < std::numeric_limits<decltype(A)>::max(), #B " overflows"); \
#define ASSIGN_CHECK_OVERFLOW(A, B) \
{ \
A = B; \
TORCH_CHECK( \
B < cutlass::platform::numeric_limits<decltype(A)>::max(), \
#B " overflows"); \
}
namespace gemm_kernel_utils {
#ifdef HAS_PYTORCH
template <typename scalar_t>
struct TypeTraits;
@@ -94,7 +107,6 @@ template <>
struct TypeTraits<cutlass::half_t> {
using scalar_t = cutlass::half_t;
#ifdef HAS_PYTORCH
static constexpr __host__ at::ScalarType atScalarType() {
return at::ScalarType::Half;
}
@@ -106,14 +118,12 @@ struct TypeTraits<cutlass::half_t> {
tensor.sizes().data(),
tensor.strides().data());
}
#endif
};
template <>
struct TypeTraits<cutlass::bfloat16_t> {
using scalar_t = cutlass::bfloat16_t;
#ifdef HAS_PYTORCH
static constexpr __host__ at::ScalarType atScalarType() {
return at::ScalarType::BFloat16;
}
@@ -125,16 +135,12 @@ struct TypeTraits<cutlass::bfloat16_t> {
tensor.sizes().data(),
tensor.strides().data());
}
#endif
};
template <>
struct TypeTraits<float> {
using scalar_t = float;
static constexpr scalar_t kInf = std::numeric_limits<scalar_t>::infinity();
static constexpr scalar_t kMinusInf = -std::numeric_limits<scalar_t>::infinity();
#ifdef HAS_PYTORCH
static constexpr __host__ at::ScalarType atScalarType() {
return at::ScalarType::Float;
}
@@ -143,8 +149,8 @@ struct TypeTraits<float> {
at::Tensor const& tensor) {
return tensor.packed_accessor32<scalar_t, nDim>();
}
#endif
};
#endif
template <typename integer>
constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) {
@@ -172,7 +178,8 @@ template <typename ArchTag>
struct DefaultGemmType<
ArchTag,
float,
typename std::enable_if<ArchTag::kMinComputeCapability >= 80>::type> {
typename cutlass::platform::enable_if<
ArchTag::kMinComputeCapability >= 80>::type> {
static constexpr int ThreadK = 32;
static constexpr int WarpK = 32;
static constexpr int kMinimumAlignment = 4;
@@ -186,7 +193,7 @@ template <typename ArchTag, typename scalar_t>
struct DefaultGemmType<
ArchTag,
scalar_t,
typename std::enable_if<
typename cutlass::platform::enable_if<
ArchTag::kMinComputeCapability >= 75 &&
cutlass::sizeof_bits<scalar_t>::value == 16>::type> {
static constexpr int ThreadK = 32;
@@ -216,17 +223,19 @@ struct call_conditional;
template <typename TA, typename TB>
struct call_conditional<true, TA, TB> {
template <typename... Args>
static CUTLASS_DEVICE auto apply(TA ta, TB tb, Args&&... args) -> decltype(ta(std::forward<Args>(args)...)) {
return ta(std::forward<Args>(args)...);
template <typename Arg>
static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg)
-> decltype(ta(arg)) {
return ta(arg);
}
};
template <typename TA, typename TB>
struct call_conditional<false, TA, TB> {
template <typename... Args>
static CUTLASS_DEVICE auto apply(TA ta, TB tb, Args&&... args) -> decltype(tb(std::forward<Args>(args)...)) {
return tb(std::forward<Args>(args)...);
template <typename Arg>
static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg)
-> decltype(tb(arg)) {
return tb(arg);
}
};

View File

@@ -44,7 +44,8 @@ namespace {
template <typename scalar_t, typename Arch>
constexpr int getWarpsPerSm() {
return (
Arch::kMinComputeCapability >= 80 && !std::is_same<scalar_t, float>::value
Arch::kMinComputeCapability >= 80 &&
!cutlass::platform::is_same<scalar_t, float>::value
? 16
: 12);
}
@@ -75,8 +76,8 @@ struct AttentionKernel {
static constexpr bool kPreloadV = ArchTag::kMinComputeCapability >= 80 &&
cutlass::sizeof_bits<scalar_t>::value == 16;
static constexpr bool kKeepOutputInRF = kSingleValueIteration;
static constexpr bool kNeedsOutputAccumulatorBuffer =
!kKeepOutputInRF && !std::is_same<output_accum_t, output_t>::value;
static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF &&
!cutlass::platform::is_same<output_accum_t, output_t>::value;
static_assert(kQueriesPerBlock % 32 == 0, "");
static_assert(kKeysPerBlock % 32 == 0, "");
@@ -413,7 +414,7 @@ struct AttentionKernel {
}
};
using SharedStorage = typename std::conditional<
using SharedStorage = typename cutlass::platform::conditional<
kSingleValueIteration || kKeepOutputInRF,
SharedStorageEpilogueAtEnd,
SharedStorageEpilogueInLoop>::type;
@@ -446,8 +447,9 @@ struct AttentionKernel {
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
if (thread_id() < kQueriesPerBlock) {
s_prime[thread_id()] = accum_t(0);
m_prime[thread_id()] = gemm_kernel_utils::TypeTraits<accum_t>::kMinusInf;
mi[thread_id()] = gemm_kernel_utils::TypeTraits<accum_t>::kMinusInf;
m_prime[thread_id()] =
-cutlass::platform::numeric_limits<accum_t>::infinity();
mi[thread_id()] = -cutlass::platform::numeric_limits<accum_t>::infinity();
}
typename MM1::Mma::FragmentC accum_o;
accum_o.clear();
@@ -580,7 +582,8 @@ struct AttentionKernel {
},
[&](int accum_m, int accum_n, int idx) {
if (accum_n > last_col) {
accum[idx] = gemm_kernel_utils::TypeTraits<accum_t>::kMinusInf;
accum[idx] =
-cutlass::platform::numeric_limits<accum_t>::infinity();
}
},
[&](int accum_m) {});
@@ -609,7 +612,7 @@ struct AttentionKernel {
warp_id(),
p.num_keys - iter_key_start,
iteratorC_tile_offset,
1.0f / std::sqrt(float(p.head_dim)));
1.0f / cutlass::fast_sqrt(float(p.head_dim)));
}));
}));
@@ -683,7 +686,7 @@ struct AttentionKernel {
using ElementCompute = typename DefaultOp::ElementCompute;
using EpilogueOutputOp = typename cutlass::epilogue::
thread::MemoryEfficientAttentionNormalize<
typename std::conditional<
typename cutlass::platform::conditional<
kIsLast,
output_t,
output_accum_t>::type,
@@ -699,7 +702,7 @@ struct AttentionKernel {
typename DefaultEpilogue::Shape,
typename MM1::Mma::Operator,
DefaultEpilogue::kPartitionsK,
typename std::conditional<
typename cutlass::platform::conditional<
kIsLast,
typename MM1::OutputTileIterator,
typename MM1::OutputTileIteratorAccum>::type,
@@ -787,11 +790,11 @@ struct AttentionKernel {
if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) {
auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE;
if (thread_id() < p.num_queries) {
p.logsumexp_ptr[thread_id()] =
accum_t(mi[thread_id()]) + std::log(accum_t(s_prime[thread_id()]));
p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()]) +
cutlass::fast_log(accum_t(s_prime[thread_id()]));
} else if (thread_id() < lse_dim) {
p.logsumexp_ptr[thread_id()] =
gemm_kernel_utils::TypeTraits<accum_t>::kInf;
cutlass::platform::numeric_limits<accum_t>::infinity();
}
}
}

View File

@@ -1031,14 +1031,6 @@ class MmaMultistageFromSharedMemory : public MmaBaseFromSharedMemory<
}
};
namespace {
// Compile-time type-equality check: instantiating AssertIsSame<A, B> (e.g.
// by naming AssertIsSame<A, B>::CHECK) fails the static_assert unless A and
// B are the same type. The CHECK alias exists so call sites have a member to
// name, which forces the instantiation and thus the assertion.
// NOTE(review): the single-argument static_assert form requires C++17 —
// confirm the minimum language standard of every compiler this must build
// under (including NVRTC).
template <typename A, typename B>
struct AssertIsSame {
static_assert(std::is_same<A, B>::value);
using CHECK = bool;
};
} // namespace
template <
typename WarpShape,
typename InstructionShape,
@@ -1264,7 +1256,8 @@ struct DefaultMmaFromSharedMemory<
static int constexpr kMaxK = AccumulatorSharedStorage_::Shape::kN;
// Reduce the number of stages if we don't need that many
static int constexpr kStagesMax = (kMaxK + int(Shape_::kK) - 1) / int(Shape_::kK);
static int constexpr kStagesMax =
(kMaxK + int(Shape_::kK) - 1) / int(Shape_::kK);
static int constexpr kStages = cutlass::const_min(Stages, kStagesMax);
using IteratorB =
@@ -1630,7 +1623,7 @@ struct B2bGemm<
if (rowIdx == 1) {
lse_prefetched[colIdx] = accum_n < lse_extent
? lse[accum_n]
: std::numeric_limits<accum_t>::infinity();
: platform::numeric_limits<accum_t>::infinity();
}
accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]);
++colIdx;
@@ -1770,7 +1763,7 @@ struct B2bGemm<
if (rowIdx == 1) {
lse_prefetched[colIdx] = accum_n < lse_extent
? lse[accum_n]
: std::numeric_limits<accum_t>::infinity();
: platform::numeric_limits<accum_t>::infinity();
}
accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]);
++colIdx;

View File

@@ -574,6 +574,21 @@ using std::is_trivially_copyable;
#endif
//-----------------------------------------------------------------------------
// bit_cast <bit>
//-----------------------------------------------------------------------------
// Minimal stand-in for C++20 std::bit_cast<To>(from): reinterprets the object
// representation of `src` as a value of type `To`, preserving every bit.
// NOTE(review): the body uses reinterpret_cast, which is not permitted in a
// constant expression — so despite being declared `constexpr`, a conforming
// compiler cannot evaluate this at compile time (the real std::bit_cast needs
// compiler support such as __builtin_bit_cast). Presumably the target
// compilers accept the uses below (e.g. numeric_limits<float>::infinity());
// verify on every supported toolchain.
template< class To, class From >
constexpr To CUTLASS_HOST_DEVICE bit_cast(const From& from ) noexcept;
template <class To, class From>
constexpr To CUTLASS_HOST_DEVICE bit_cast(const From& src) noexcept
{
// Same precondition as std::bit_cast: the reinterpretation is only
// meaningful when source and destination have identical sizes.
static_assert(sizeof(To) == sizeof(From), "sizes must match");
return reinterpret_cast<To const &>(src);
}
//-----------------------------------------------------------------------------
// Alignment and layout utilities
//-----------------------------------------------------------------------------
@@ -865,5 +880,13 @@ struct numeric_limits<uint8_t> {
static constexpr bool is_integer = true;
};
// Specialization of cutlass::platform::numeric_limits for float, giving
// device code an infinity() without relying on the host-only
// std::numeric_limits. 0x7f800000 is the IEEE-754 binary32 bit pattern for
// +infinity (exponent all ones, zero mantissa), materialized via bit_cast.
// NOTE(review): infinity() is constexpr but delegates to a bit_cast built on
// reinterpret_cast, which is not valid in a constant expression — confirm
// the intended callers only need runtime (not compile-time) evaluation.
template <>
struct numeric_limits<float> {
CUTLASS_HOST_DEVICE
static constexpr float infinity() noexcept { return bit_cast<float, int32_t>(0x7f800000);}
static constexpr bool is_integer = false;
static constexpr bool has_infinity = true;
};
} // namespace platform
} // namespace cutlass