/*************************************************************************************************** * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without * modification, are permitted provided that the following conditions are met: * * 1. Redistributions of source code must retain the above copyright notice, this * list of conditions and the following disclaimer. * * 2. Redistributions in binary form must reproduce the above copyright notice, * this list of conditions and the following disclaimer in the documentation * and/or other materials provided with the distribution. * * 3. Neither the name of the copyright holder nor the names of its * contributors may be used to endorse or promote products derived from * this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ /*! \file \brief Defines a class for using integer types smaller than one byte in host or device code. */ /* Note: CUTLASS 3x increases the host compiler requirements to C++17. However, certain existing integrations of CUTLASS require C++11 host compilers. Until this requirement can be lifted, certain headers with this annotation are required to be remain consistent with C++11 syntax. C++11 compatibility is enforced by `cutlass_test_unit_core_cpp11`. */ #pragma once #if defined(__CUDACC_RTC__) #include #else #include #endif #include "cutlass/cutlass.h" #include "cutlass/numeric_size.h" #include "cutlass/platform/platform.h" namespace cutlass { /////////////////////////////////////////////////////////////////////////////////////////////////// /// 4-bit signed integer type template struct integer_subbyte { /// Number of bits static int const kBits = Bits; /// Whether type is signed static bool const kSigned = Signed; /// External type using T = typename platform::conditional::type; /// Storage type using Storage = uint8_t; /// Bitmask used to truncate from larger integers static Storage const kMask = Storage((1 << kBits) - 1); // // Data members // Storage storage; // // Methods // /// No operation integer_subbyte() = default; /// Conversion from integer type CUTLASS_HOST_DEVICE integer_subbyte(int value) : storage(reinterpret_cast(value) & kMask) {} CUTLASS_HOST_DEVICE integer_subbyte(unsigned value) : storage(reinterpret_cast(value) & kMask) {} CUTLASS_HOST_DEVICE integer_subbyte(double value) { T tmp = static_cast(value); storage = Storage(reinterpret_cast(tmp) & kMask); } /// CUTLASS_HOST_DEVICE operator T() const { if (kSigned) { // Sign extend if (storage & Storage(1 << (kBits - 1))) { return T(storage) | ~T(kMask); } } return T(storage); } /// Equality CUTLASS_HOST_DEVICE bool operator==(integer_subbyte const &rhs) const { return storage == rhs.storage; } /// Inequality CUTLASS_HOST_DEVICE bool operator!=(integer_subbyte const &rhs) const { return storage != rhs.storage; } /// Less than or equal CUTLASS_HOST_DEVICE bool operator<=(integer_subbyte const &rhs) const { if (kSigned) { if (storage & (1 << (kBits - 1))) { return !(rhs.storage < storage); } } return storage <= rhs.storage; } /// Less than CUTLASS_HOST_DEVICE bool operator<(integer_subbyte const &rhs) const { if (kSigned) { if (storage & (1 << (kBits - 1))) { return !(rhs.storage <= storage); } } return storage < rhs.storage; } /// Greater than or equal CUTLASS_HOST_DEVICE bool operator>=(integer_subbyte const &rhs) const { return !(*this < rhs); } /// Greater than CUTLASS_HOST_DEVICE bool operator>(integer_subbyte const &rhs) const { return !(*this <= rhs); } }; /////////////////////////////////////////////////////////////////////////////////////////////////// /// 1-bit Unsigned integer type using uint1b_t = integer_subbyte<1, false>; /// 2-bit Integer type using int2b_t = integer_subbyte<2, true>; /// 2-bit Unsigned integer type using uint2b_t = integer_subbyte<2, false>; /// 4-bit Integer type using int4b_t = integer_subbyte<4, true>; /// 4-bit Unsigned integer type using uint4b_t = integer_subbyte<4, false>; /////////////////////////////////////////////////////////////////////////////////////////////////// /// Defines the size of an element in bits - specialized for uint1b_t template <> struct sizeof_bits { static int const value = 1; }; /// Defines the size of an element in bits - specialized for int2b_t template <> struct sizeof_bits { static int const value = 2; }; /// Defines the size of an element in bits - specialized for uint2b_t template <> struct sizeof_bits { static int const value = 2; }; /// Defines the size of an element in bits - specialized for int4b_t template <> struct sizeof_bits { static int const value = 4; }; /// Defines the size of an element in bits - specialized for uint4b_t template <> struct sizeof_bits { static int const value = 4; }; /////////////////////////////////////////////////////////////////////////////////////////////////// namespace platform { template <> struct numeric_limits { CUTLASS_HOST_DEVICE static cutlass::int4b_t const lowest() noexcept { return -8;} CUTLASS_HOST_DEVICE static cutlass::int4b_t const max() noexcept { return 7;} static constexpr bool is_integer = true; }; template <> struct numeric_limits { CUTLASS_HOST_DEVICE static cutlass::uint4b_t const lowest() noexcept { return 0;} CUTLASS_HOST_DEVICE static cutlass::uint4b_t const max() noexcept { return 15;} static constexpr bool is_integer = true; }; template <> struct numeric_limits { CUTLASS_HOST_DEVICE static cutlass::uint1b_t const lowest() noexcept { return 0;} CUTLASS_HOST_DEVICE static cutlass::uint1b_t const max() noexcept { return 1;} static constexpr bool is_integer = true; }; /////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace platform } // namespace cutlass