mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-12 17:25:45 +00:00
Remove std:: references
This commit is contained in:
@@ -7,7 +7,6 @@
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "gemm_kernel_utils.h"
|
||||
|
||||
|
||||
namespace {
|
||||
|
||||
static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) {
|
||||
@@ -69,9 +68,7 @@ struct RegisterOps {
|
||||
BASE::iterateRows(
|
||||
lane_offset,
|
||||
[&](int accum_m) {
|
||||
// XXX: This is a __host__ function, so we need to workaround
|
||||
// max = -std::numeric_limits<accum_t>::infinity();
|
||||
max = gemm_kernel_utils::TypeTraits<accum_t>::kMinusInf;
|
||||
max = -cutlass::platform::numeric_limits<accum_t>::infinity();
|
||||
},
|
||||
[&](int accum_m, int accum_n, int idx) {
|
||||
if (kFullColumns || accum_n < max_col) {
|
||||
@@ -140,7 +137,9 @@ struct AttentionScalingCoefsUpdaterSm80
|
||||
accum_t,
|
||||
kWarpSize> {
|
||||
static_assert(
|
||||
std::is_same<typename T::Layout, cutlass::layout::RowMajor>::value);
|
||||
cutlass::platform::
|
||||
is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
|
||||
"only RowMajor is supported");
|
||||
|
||||
using Policy = typename T::Policy;
|
||||
using InstructionShape = typename T::InstructionShape;
|
||||
@@ -208,10 +207,6 @@ struct AttentionScalingCoefsUpdaterSm80
|
||||
}
|
||||
};
|
||||
|
||||
// cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator<cutlass::MatrixShape<32,
|
||||
// 32>, float, cutlass::layout::RowMajor, cutlass::gemm::GemmShape<16, 16, 4>,
|
||||
// cutlass::MatrixShape<1, 1>> See
|
||||
// cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h
|
||||
template <typename T, typename accum_t, int kWarpSize>
|
||||
struct AttentionScalingCoefsUpdaterVolta
|
||||
: RegisterOps<
|
||||
@@ -220,7 +215,9 @@ struct AttentionScalingCoefsUpdaterVolta
|
||||
accum_t,
|
||||
kWarpSize> {
|
||||
static_assert(
|
||||
std::is_same<typename T::Layout, cutlass::layout::RowMajor>::value);
|
||||
cutlass::platform::
|
||||
is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
|
||||
"only RowMajor is supported");
|
||||
|
||||
using Policy = typename T::Policy;
|
||||
using InstructionShape = typename T::InstructionShape;
|
||||
@@ -341,9 +338,9 @@ struct AttentionScalingCoefsUpdaterSimt
|
||||
using Delta = typename T::Delta;
|
||||
using Shape = typename T::Shape;
|
||||
static_assert(
|
||||
std::is_same<typename T::Layout, cutlass::layout::RowMajor>::value);
|
||||
static_assert(
|
||||
std::is_same<typename T::Iterations, typename T::Iterations>::value);
|
||||
cutlass::platform::
|
||||
is_same<typename T::Layout, cutlass::layout::RowMajor>::value,
|
||||
"only RowMajor is supported");
|
||||
|
||||
template <typename DT, typename F>
|
||||
CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) {
|
||||
@@ -392,9 +389,11 @@ struct AttentionScalingCoefsUpdaterSimt
|
||||
int8_t lane_id,
|
||||
int8_t warp_id,
|
||||
typename T::TensorCoord const& tile_offset) {
|
||||
static_assert(std::is_same<
|
||||
typename Policy::LaneLayout,
|
||||
cutlass::layout::RowMajorInterleaved<1>>::value);
|
||||
static_assert(
|
||||
cutlass::platform::is_same<
|
||||
typename Policy::LaneLayout,
|
||||
cutlass::layout::RowMajorInterleaved<1>>::value,
|
||||
"");
|
||||
typename Policy::LaneLayout lane_layout = Policy::get_lane_layout();
|
||||
|
||||
cutlass::MatrixCoord lane_offset = lane_layout.inverse(lane_id) *
|
||||
|
||||
@@ -114,7 +114,7 @@ struct FindDefaultMma<
|
||||
InstructionShape,
|
||||
kStages,
|
||||
Operator,
|
||||
typename std::enable_if<(kAlignmentA > 1)>::type> {
|
||||
typename cutlass::platform::enable_if<(kAlignmentA > 1)>::type> {
|
||||
using LayoutC = layout::RowMajor;
|
||||
using OperatorClass = arch::OpClassTensorOp;
|
||||
using ArchTag = arch::Sm80;
|
||||
|
||||
@@ -38,9 +38,12 @@ struct MakeCustomMma<
|
||||
SharedMemoryClear>,
|
||||
kMaxK> {
|
||||
// Reduce the number of stages if we don't need that many
|
||||
static int constexpr kStages = kMaxK == std::numeric_limits<int>::max()
|
||||
static int constexpr kStages =
|
||||
kMaxK == cutlass::platform::numeric_limits<int>::max()
|
||||
? Stages
|
||||
: std::min(Stages, (kMaxK + int(Shape::kK) - 1) / int(Shape::kK));
|
||||
: cutlass::const_min(
|
||||
Stages,
|
||||
(kMaxK + int(Shape::kK) - 1) / int(Shape::kK));
|
||||
using Mma = cutlass::gemm::threadblock::CustomMmaMultistage<
|
||||
Shape,
|
||||
IteratorA,
|
||||
|
||||
@@ -88,7 +88,7 @@ template <
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
|
||||
/// Upper bound on the K dimension
|
||||
int kMaxK = std::numeric_limits<int>::max(),
|
||||
int kMaxK = cutlass::platform::numeric_limits<int>::max(),
|
||||
/// Used for partial specialization
|
||||
typename Enable = bool>
|
||||
class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
|
||||
|
||||
@@ -67,26 +67,39 @@
|
||||
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
|
||||
TORCH_CHECK(uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned")
|
||||
#define XFORMERS_CHECK TORCH_CHECK
|
||||
#else
|
||||
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
|
||||
if(!(uint64_t(PTR) % ALIGNMENT == 0)) {\
|
||||
std::cerr << #PTR " is not correctly aligned\n";\
|
||||
return false;\
|
||||
#elif defined(__CUDACC_RTC__)
|
||||
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
|
||||
if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \
|
||||
return false; \
|
||||
}
|
||||
#define XFORMERS_CHECK(COND, ERR) if (!(COND)) {\
|
||||
std::cerr << #COND " failed\n";\
|
||||
return false;\
|
||||
#define XFORMERS_CHECK(COND, ERR) \
|
||||
if (!(COND)) { \
|
||||
return false; \
|
||||
}
|
||||
#else
|
||||
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
|
||||
if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \
|
||||
std::cerr << #PTR " is not correctly aligned\n"; \
|
||||
return false; \
|
||||
}
|
||||
#define XFORMERS_CHECK(COND, ERR) \
|
||||
if (!(COND)) { \
|
||||
std::cerr << #COND " failed\n"; \
|
||||
return false; \
|
||||
}
|
||||
#endif
|
||||
|
||||
#define ASSIGN_CHECK_OVERFLOW(A, B) \
|
||||
{ \
|
||||
A = B; \
|
||||
TORCH_CHECK(B < std::numeric_limits<decltype(A)>::max(), #B " overflows"); \
|
||||
#define ASSIGN_CHECK_OVERFLOW(A, B) \
|
||||
{ \
|
||||
A = B; \
|
||||
TORCH_CHECK( \
|
||||
B < cutlass::platform::numeric_limits<decltype(A)>::max(), \
|
||||
#B " overflows"); \
|
||||
}
|
||||
|
||||
namespace gemm_kernel_utils {
|
||||
|
||||
#ifdef HAS_PYTORCH
|
||||
template <typename scalar_t>
|
||||
struct TypeTraits;
|
||||
|
||||
@@ -94,7 +107,6 @@ template <>
|
||||
struct TypeTraits<cutlass::half_t> {
|
||||
using scalar_t = cutlass::half_t;
|
||||
|
||||
#ifdef HAS_PYTORCH
|
||||
static constexpr __host__ at::ScalarType atScalarType() {
|
||||
return at::ScalarType::Half;
|
||||
}
|
||||
@@ -106,14 +118,12 @@ struct TypeTraits<cutlass::half_t> {
|
||||
tensor.sizes().data(),
|
||||
tensor.strides().data());
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeTraits<cutlass::bfloat16_t> {
|
||||
using scalar_t = cutlass::bfloat16_t;
|
||||
|
||||
#ifdef HAS_PYTORCH
|
||||
static constexpr __host__ at::ScalarType atScalarType() {
|
||||
return at::ScalarType::BFloat16;
|
||||
}
|
||||
@@ -125,16 +135,12 @@ struct TypeTraits<cutlass::bfloat16_t> {
|
||||
tensor.sizes().data(),
|
||||
tensor.strides().data());
|
||||
}
|
||||
#endif
|
||||
};
|
||||
|
||||
template <>
|
||||
struct TypeTraits<float> {
|
||||
using scalar_t = float;
|
||||
static constexpr scalar_t kInf = std::numeric_limits<scalar_t>::infinity();
|
||||
static constexpr scalar_t kMinusInf = -std::numeric_limits<scalar_t>::infinity();
|
||||
|
||||
#ifdef HAS_PYTORCH
|
||||
static constexpr __host__ at::ScalarType atScalarType() {
|
||||
return at::ScalarType::Float;
|
||||
}
|
||||
@@ -143,8 +149,8 @@ struct TypeTraits<float> {
|
||||
at::Tensor const& tensor) {
|
||||
return tensor.packed_accessor32<scalar_t, nDim>();
|
||||
}
|
||||
#endif
|
||||
};
|
||||
#endif
|
||||
|
||||
template <typename integer>
|
||||
constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) {
|
||||
@@ -172,7 +178,8 @@ template <typename ArchTag>
|
||||
struct DefaultGemmType<
|
||||
ArchTag,
|
||||
float,
|
||||
typename std::enable_if<ArchTag::kMinComputeCapability >= 80>::type> {
|
||||
typename cutlass::platform::enable_if<
|
||||
ArchTag::kMinComputeCapability >= 80>::type> {
|
||||
static constexpr int ThreadK = 32;
|
||||
static constexpr int WarpK = 32;
|
||||
static constexpr int kMinimumAlignment = 4;
|
||||
@@ -186,7 +193,7 @@ template <typename ArchTag, typename scalar_t>
|
||||
struct DefaultGemmType<
|
||||
ArchTag,
|
||||
scalar_t,
|
||||
typename std::enable_if<
|
||||
typename cutlass::platform::enable_if<
|
||||
ArchTag::kMinComputeCapability >= 75 &&
|
||||
cutlass::sizeof_bits<scalar_t>::value == 16>::type> {
|
||||
static constexpr int ThreadK = 32;
|
||||
@@ -216,17 +223,19 @@ struct call_conditional;
|
||||
|
||||
template <typename TA, typename TB>
|
||||
struct call_conditional<true, TA, TB> {
|
||||
template <typename... Args>
|
||||
static CUTLASS_DEVICE auto apply(TA ta, TB tb, Args&&... args) -> decltype(ta(std::forward<Args>(args)...)) {
|
||||
return ta(std::forward<Args>(args)...);
|
||||
template <typename Arg>
|
||||
static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg)
|
||||
-> decltype(ta(arg)) {
|
||||
return ta(arg);
|
||||
}
|
||||
};
|
||||
|
||||
template <typename TA, typename TB>
|
||||
struct call_conditional<false, TA, TB> {
|
||||
template <typename... Args>
|
||||
static CUTLASS_DEVICE auto apply(TA ta, TB tb, Args&&... args) -> decltype(tb(std::forward<Args>(args)...)) {
|
||||
return tb(std::forward<Args>(args)...);
|
||||
template <typename Arg>
|
||||
static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg)
|
||||
-> decltype(tb(arg)) {
|
||||
return tb(arg);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -44,7 +44,8 @@ namespace {
|
||||
template <typename scalar_t, typename Arch>
|
||||
constexpr int getWarpsPerSm() {
|
||||
return (
|
||||
Arch::kMinComputeCapability >= 80 && !std::is_same<scalar_t, float>::value
|
||||
Arch::kMinComputeCapability >= 80 &&
|
||||
!cutlass::platform::is_same<scalar_t, float>::value
|
||||
? 16
|
||||
: 12);
|
||||
}
|
||||
@@ -75,8 +76,8 @@ struct AttentionKernel {
|
||||
static constexpr bool kPreloadV = ArchTag::kMinComputeCapability >= 80 &&
|
||||
cutlass::sizeof_bits<scalar_t>::value == 16;
|
||||
static constexpr bool kKeepOutputInRF = kSingleValueIteration;
|
||||
static constexpr bool kNeedsOutputAccumulatorBuffer =
|
||||
!kKeepOutputInRF && !std::is_same<output_accum_t, output_t>::value;
|
||||
static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF &&
|
||||
!cutlass::platform::is_same<output_accum_t, output_t>::value;
|
||||
|
||||
static_assert(kQueriesPerBlock % 32 == 0, "");
|
||||
static_assert(kKeysPerBlock % 32 == 0, "");
|
||||
@@ -413,7 +414,7 @@ struct AttentionKernel {
|
||||
}
|
||||
};
|
||||
|
||||
using SharedStorage = typename std::conditional<
|
||||
using SharedStorage = typename cutlass::platform::conditional<
|
||||
kSingleValueIteration || kKeepOutputInRF,
|
||||
SharedStorageEpilogueAtEnd,
|
||||
SharedStorageEpilogueInLoop>::type;
|
||||
@@ -446,8 +447,9 @@ struct AttentionKernel {
|
||||
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
|
||||
if (thread_id() < kQueriesPerBlock) {
|
||||
s_prime[thread_id()] = accum_t(0);
|
||||
m_prime[thread_id()] = gemm_kernel_utils::TypeTraits<accum_t>::kMinusInf;
|
||||
mi[thread_id()] = gemm_kernel_utils::TypeTraits<accum_t>::kMinusInf;
|
||||
m_prime[thread_id()] =
|
||||
-cutlass::platform::numeric_limits<accum_t>::infinity();
|
||||
mi[thread_id()] = -cutlass::platform::numeric_limits<accum_t>::infinity();
|
||||
}
|
||||
typename MM1::Mma::FragmentC accum_o;
|
||||
accum_o.clear();
|
||||
@@ -580,7 +582,8 @@ struct AttentionKernel {
|
||||
},
|
||||
[&](int accum_m, int accum_n, int idx) {
|
||||
if (accum_n > last_col) {
|
||||
accum[idx] = gemm_kernel_utils::TypeTraits<accum_t>::kMinusInf;
|
||||
accum[idx] =
|
||||
-cutlass::platform::numeric_limits<accum_t>::infinity();
|
||||
}
|
||||
},
|
||||
[&](int accum_m) {});
|
||||
@@ -609,7 +612,7 @@ struct AttentionKernel {
|
||||
warp_id(),
|
||||
p.num_keys - iter_key_start,
|
||||
iteratorC_tile_offset,
|
||||
1.0f / std::sqrt(float(p.head_dim)));
|
||||
1.0f / cutlass::fast_sqrt(float(p.head_dim)));
|
||||
}));
|
||||
}));
|
||||
|
||||
@@ -683,7 +686,7 @@ struct AttentionKernel {
|
||||
using ElementCompute = typename DefaultOp::ElementCompute;
|
||||
using EpilogueOutputOp = typename cutlass::epilogue::
|
||||
thread::MemoryEfficientAttentionNormalize<
|
||||
typename std::conditional<
|
||||
typename cutlass::platform::conditional<
|
||||
kIsLast,
|
||||
output_t,
|
||||
output_accum_t>::type,
|
||||
@@ -699,7 +702,7 @@ struct AttentionKernel {
|
||||
typename DefaultEpilogue::Shape,
|
||||
typename MM1::Mma::Operator,
|
||||
DefaultEpilogue::kPartitionsK,
|
||||
typename std::conditional<
|
||||
typename cutlass::platform::conditional<
|
||||
kIsLast,
|
||||
typename MM1::OutputTileIterator,
|
||||
typename MM1::OutputTileIteratorAccum>::type,
|
||||
@@ -787,11 +790,11 @@ struct AttentionKernel {
|
||||
if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) {
|
||||
auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE;
|
||||
if (thread_id() < p.num_queries) {
|
||||
p.logsumexp_ptr[thread_id()] =
|
||||
accum_t(mi[thread_id()]) + std::log(accum_t(s_prime[thread_id()]));
|
||||
p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()]) +
|
||||
cutlass::fast_log(accum_t(s_prime[thread_id()]));
|
||||
} else if (thread_id() < lse_dim) {
|
||||
p.logsumexp_ptr[thread_id()] =
|
||||
gemm_kernel_utils::TypeTraits<accum_t>::kInf;
|
||||
cutlass::platform::numeric_limits<accum_t>::infinity();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1031,14 +1031,6 @@ class MmaMultistageFromSharedMemory : public MmaBaseFromSharedMemory<
|
||||
}
|
||||
};
|
||||
|
||||
namespace {
|
||||
template <typename A, typename B>
|
||||
struct AssertIsSame {
|
||||
static_assert(std::is_same<A, B>::value);
|
||||
using CHECK = bool;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
template <
|
||||
typename WarpShape,
|
||||
typename InstructionShape,
|
||||
@@ -1264,7 +1256,8 @@ struct DefaultMmaFromSharedMemory<
|
||||
|
||||
static int constexpr kMaxK = AccumulatorSharedStorage_::Shape::kN;
|
||||
// Reduce the number of stages if we don't need that many
|
||||
static int constexpr kStagesMax = (kMaxK + int(Shape_::kK) - 1) / int(Shape_::kK);
|
||||
static int constexpr kStagesMax =
|
||||
(kMaxK + int(Shape_::kK) - 1) / int(Shape_::kK);
|
||||
static int constexpr kStages = cutlass::const_min(Stages, kStagesMax);
|
||||
|
||||
using IteratorB =
|
||||
@@ -1630,7 +1623,7 @@ struct B2bGemm<
|
||||
if (rowIdx == 1) {
|
||||
lse_prefetched[colIdx] = accum_n < lse_extent
|
||||
? lse[accum_n]
|
||||
: std::numeric_limits<accum_t>::infinity();
|
||||
: platform::numeric_limits<accum_t>::infinity();
|
||||
}
|
||||
accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]);
|
||||
++colIdx;
|
||||
@@ -1770,7 +1763,7 @@ struct B2bGemm<
|
||||
if (rowIdx == 1) {
|
||||
lse_prefetched[colIdx] = accum_n < lse_extent
|
||||
? lse[accum_n]
|
||||
: std::numeric_limits<accum_t>::infinity();
|
||||
: platform::numeric_limits<accum_t>::infinity();
|
||||
}
|
||||
accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]);
|
||||
++colIdx;
|
||||
|
||||
@@ -574,6 +574,21 @@ using std::is_trivially_copyable;
|
||||
|
||||
#endif
|
||||
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// bit_cast <bit>
//
// Stand-in for the <bit> header's bit_cast inside cutlass::platform: returns a
// value of type To obtained by viewing the storage of a From object as To.
// To and From must have the same size (checked at compile time below).
|
||||
//-----------------------------------------------------------------------------
|
||||
|
||||
template< class To, class From >
|
||||
constexpr To CUTLASS_HOST_DEVICE bit_cast(const From& from ) noexcept; // forward declaration; the definition follows
|
||||
|
||||
template <class To, class From>
|
||||
constexpr To CUTLASS_HOST_DEVICE bit_cast(const From& src) noexcept
|
||||
{
|
||||
static_assert(sizeof(To) == sizeof(From), "sizes must match"); // reject size-mismatched casts at compile time
|
||||
// NOTE(review): reinterpret_cast is not permitted in a constant expression, so
// despite the constexpr specifier this cannot be evaluated at compile time.
// It also reads `src` through a reference to an unrelated type — confirm this
// is acceptable on every supported compiler.
return reinterpret_cast<To const &>(src);
|
||||
}
|
||||
|
||||
//-----------------------------------------------------------------------------
|
||||
// Alignment and layout utilities
|
||||
//-----------------------------------------------------------------------------
|
||||
@@ -865,5 +880,13 @@ struct numeric_limits<uint8_t> {
|
||||
static constexpr bool is_integer = true;
|
||||
};
|
||||
|
||||
/// numeric_limits specialization for float. Only infinity(), is_integer and
/// has_infinity are provided here.
template <>
|
||||
struct numeric_limits<float> {
|
||||
CUTLASS_HOST_DEVICE
|
||||
// Returns positive infinity by reinterpreting the 32-bit pattern 0x7f800000
// (sign 0, exponent all ones, mantissa 0) as a float via bit_cast.
// Assumes float is IEEE-754 binary32 — TODO confirm for all targets.
static constexpr float infinity() noexcept { return bit_cast<float, int32_t>(0x7f800000);}
|
||||
static constexpr bool is_integer = false; // float is not an integer type
|
||||
static constexpr bool has_infinity = true; // infinity() above returns a meaningful value
|
||||
};
|
||||
|
||||
} // namespace platform
|
||||
} // namespace cutlass
|
||||
|
||||
Reference in New Issue
Block a user