mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-11 17:00:05 +00:00
Add support for mixed 4-bit/8-bit data types GEMM (#1413)
* Add support for mixed 4-bit/8-bit data types GEMM * fix ( and ) --------- Co-authored-by: Aleksandar Samardžić <asamardzic@matf.bg.ac.rs> Co-authored-by: Haicheng Wu <haichengw@nvidia.com>
This commit is contained in:
committed by
GitHub
parent
f7b19de32c
commit
e1976daacc
@@ -793,6 +793,60 @@ struct DefaultGemmConfigurationSm89F8 {
|
||||
using Operator = arch::OpMultiplyAdd;
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization: mixed-input integer GEMM on SM80 tensor cores with
/// A = int4b_t, B = int8_t, and an int32_t accumulator.
template <
    /// Element type of the C matrix (epilogue output)
    typename ElementC>
struct DefaultGemmConfiguration<
    arch::OpClassTensorOp,
    arch::Sm80,
    int4b_t,
    int8_t,
    ElementC,
    int32_t> {

  // Alignment in elements for a 128-bit access; sizeof_bits<> is used so the
  // sub-byte int4b_t element width is accounted for correctly.
  static int const kAlignmentA = 128 / sizeof_bits<int4b_t>::value;
  static int const kAlignmentB = 128 / sizeof_bits<int8_t>::value;

  // Default tile shapes and pipeline depth for this configuration.
  using ThreadblockShape = GemmShape<128, 256, 64>;
  using WarpShape = GemmShape<64, 64, 64>;
  using InstructionShape = GemmShape<16, 8, 32>;
  static int const kStages = 3;

  // Epilogue clamps the int32_t accumulator while converting to ElementC;
  // the vector width again corresponds to 128-bit accesses.
  using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp<
      ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;

  using Operator = arch::OpMultiplyAddSaturate;
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization: mixed-input integer GEMM on SM80 tensor cores with
/// A = int8_t, B = int4b_t, and an int32_t accumulator (mirror of the
/// int4b_t x int8_t configuration with the operand types swapped).
template <
    /// Element type of the C matrix (epilogue output)
    typename ElementC>
struct DefaultGemmConfiguration<
    arch::OpClassTensorOp,
    arch::Sm80,
    int8_t,
    int4b_t,
    ElementC,
    int32_t> {

  // Alignment in elements for a 128-bit access; sizeof_bits<> is used so the
  // sub-byte int4b_t element width is accounted for correctly.
  static int const kAlignmentA = 128 / sizeof_bits<int8_t>::value;
  static int const kAlignmentB = 128 / sizeof_bits<int4b_t>::value;

  // Default tile shapes and pipeline depth for this configuration.
  using ThreadblockShape = GemmShape<128, 256, 64>;
  using WarpShape = GemmShape<64, 64, 64>;
  using InstructionShape = GemmShape<16, 8, 32>;
  static int const kStages = 3;

  // Epilogue clamps the int32_t accumulator while converting to ElementC;
  // the vector width again corresponds to 128-bit accesses.
  using EpilogueOutputOp = epilogue::thread::LinearCombinationClamp<
      ElementC, 128 / sizeof_bits<ElementC>::value, int32_t, float>;

  using Operator = arch::OpMultiplyAddSaturate;
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for SM89 fe4m3 x fe4m3
|
||||
template <typename ElementC, typename ElementAccumulator>
|
||||
struct DefaultGemmConfiguration<
|
||||
|
||||
@@ -268,7 +268,7 @@ struct DefaultMmaTensorOp<
|
||||
"DefaultMmaTensorOp with arch::OpMultiplyAddMixedInputUpcast ElementA and ElementB cannot be of the same data type");
|
||||
|
||||
// Data type used for internal computation - use the wider of the two data types for mma.sync operands
|
||||
using ElementOperand = typename platform::conditional<(sizeof(ElementA) > sizeof(ElementB)),
|
||||
using ElementOperand = typename platform::conditional<(sizeof_bits<ElementA>::value > sizeof_bits<ElementB>::value),
|
||||
ElementA, ElementB>::type;
|
||||
|
||||
// Operand datatypes in the internal MMA instruction - use the wider of the two data types
|
||||
@@ -294,6 +294,75 @@ struct DefaultMmaTensorOp<
|
||||
Policy, PartitionsK, AccumulatorsInRowMajor>;
|
||||
};
|
||||
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial Specialization - inputs are mixed types - uses wider datatype internally.
/// (e.g. S32 <= S4 x S8 + S32, S32 <= S8 x S4 + S32)
template <
    /// Shape of one matrix production operation (concept: GemmShape)
    typename WarpShape_,
    /// Element type of A matrix
    typename ElementA,
    /// Layout of A matrix (concept: MatrixLayout)
    typename LayoutA,
    /// Element type of B matrix
    typename ElementB,
    /// Layout of B matrix (concept: MatrixLayout)
    typename LayoutB,
    /// Element type of C matrix
    typename ElementC,
    /// Layout of C matrix (concept: MatrixLayout)
    typename LayoutC,
    /// Number of partitions along K dimension
    int PartitionsK,
    /// Store the accumulators in row major or column major.  Row major is used
    /// when output layout is interleaved.
    bool AccumulatorsInRowMajor>
struct DefaultMmaTensorOp<
    WarpShape_,
    GemmShape<16, 8, 32>,                // InstructionShape
    ElementA,                            // Element type of A matrix in Global Memory
    LayoutA,                             // Layout of A matrix in Global Memory
    ElementB,                            // Element type of B matrix in Global Memory
    LayoutB,                             // Layout of B matrix in Global Memory
    ElementC,                            // Element type of C matrix in Global Memory
    LayoutC,                             // Layout of C matrix in Global Memory
    arch::OpMultiplyAddMixedInputUpcast, // Tag to indicate mixed-input datatype, where narrower datatype is upcasted to wider datatype
    PartitionsK, AccumulatorsInRowMajor> {

  // Check if the ElementA and ElementB are of different data types
  static_assert(!platform::is_same<ElementA, ElementB>::value,
                "DefaultMmaTensorOp with arch::OpMultiplyAddMixedInputUpcast ElementA and ElementB cannot be of the same data type");

  // Data type used for internal computation - use the wider of the two data
  // types for mma.sync operands.  Note: sizeof_bits<> (not sizeof) is used so
  // that sub-byte types such as int4b_t are compared by their true bit width.
  using ElementOperand = typename platform::conditional<(sizeof_bits<ElementA>::value > sizeof_bits<ElementB>::value),
                                                        ElementA, ElementB>::type;

  // Operand datatypes in the internal MMA instruction - use the wider of the two data types
  using MmaElementA = ElementOperand;
  using MmaElementB = ElementOperand;
  using MmaElementC = ElementC;

  // Uses an mma.sync with saturating accumulate on the (upcasted) common
  // operand type; both internal operands use the wider element type.
  using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
      cutlass::arch::Mma<
          GemmShape<16, 8, 32>,
          32,
          MmaElementA, cutlass::layout::RowMajor,
          MmaElementB, cutlass::layout::ColumnMajor,
          MmaElementC, cutlass::layout::RowMajor,
          arch::OpMultiplyAddSaturate
      >,
      cutlass::MatrixShape<1, 1> >;

  // Define the warp-level tensor op
  using Type = cutlass::gemm::warp::MmaMixedInputTensorOp<
      WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
      Policy, PartitionsK, AccumulatorsInRowMajor>;
};
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace warp
|
||||
|
||||
@@ -104,6 +104,7 @@ struct FragmentShuffler {
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 8b (S8/U8)
|
||||
/// or for `mma.sync` on 8b (S8/U8) and `ldmatrix` on 4b (S4/U4)
|
||||
/// for operand A multiplicand going through upcasting.
|
||||
template <
|
||||
/// Element type for the operand in registers for the mma.sync
|
||||
@@ -122,8 +123,10 @@ struct FragmentShuffler <ElementMma_, ElementLoad_,
|
||||
NumElementsInWarpFragment,
|
||||
NumElementsInMmaFragment,
|
||||
Operand::kA,
|
||||
typename platform::enable_if<(sizeof_bits<ElementMma_>::value == 16) &&
|
||||
(sizeof_bits<ElementLoad_>::value == 8)>::type> {
|
||||
typename platform::enable_if<((sizeof_bits<ElementMma_>::value == 16) &&
|
||||
(sizeof_bits<ElementLoad_>::value == 8)) ||
|
||||
((sizeof_bits<ElementMma_>::value == 8) &&
|
||||
(sizeof_bits<ElementLoad_>::value == 4))>::type> {
|
||||
public:
|
||||
using ElementMma = ElementMma_;
|
||||
using ElementLoad = ElementLoad_;
|
||||
@@ -187,6 +190,7 @@ public:
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 8b (S8/U8)
|
||||
/// or for `mma.sync` on 8b (S8/U8) and `ldmatrix` on 4b (S4/U4)
|
||||
/// for operand B multiplicand going through upcasting.
|
||||
template <
|
||||
/// Element type for the operand in registers for the mma.sync
|
||||
@@ -205,8 +209,10 @@ struct FragmentShuffler <ElementMma_, ElementLoad_,
|
||||
NumElementsInWarpFragment,
|
||||
NumElementsInMmaFragment,
|
||||
Operand::kB,
|
||||
typename platform::enable_if<(sizeof_bits<ElementMma_>::value == 16) &&
|
||||
(sizeof_bits<ElementLoad_>::value == 8)>::type> {
|
||||
typename platform::enable_if<((sizeof_bits<ElementMma_>::value == 16) &&
|
||||
(sizeof_bits<ElementLoad_>::value == 8)) ||
|
||||
((sizeof_bits<ElementMma_>::value == 8) &&
|
||||
(sizeof_bits<ElementLoad_>::value == 4))>::type> {
|
||||
public:
|
||||
using ElementMma = ElementMma_;
|
||||
using ElementLoad = ElementLoad_;
|
||||
|
||||
@@ -2771,6 +2771,86 @@ struct NumericArrayConverter<uint4b_t, int, N, Round> {
|
||||
}
|
||||
};
|
||||
|
||||
/// Partial specialization for Array<int8_t, 8> <= Array<int4b_t, 8>
///
/// Sign-extends eight packed signed 4-bit values (one 32-bit register) into
/// eight signed 8-bit values (two 32-bit registers), preserving element order.
template <
    FloatRoundStyle Round
>
struct NumericArrayConverter<int8_t, int4b_t, 8, Round> {

  using result_type = Array<int8_t, 8>;
  using source_type = Array<int4b_t, 8>;
  static FloatRoundStyle const round_style = Round;

  CUTLASS_HOST_DEVICE
  static result_type convert(source_type const & source) {

    // All eight 4-bit inputs are packed into a single 32-bit word.
    unsigned const& storage = reinterpret_cast<unsigned const &>(source);
    unsigned out[2];

    // The sequence below processes the even (low-nibble-of-each-byte) and odd
    // (high-nibble) elements separately, sign-extending each nibble to a byte,
    // then interleaves the two partial results back into the original order.
    // NOTE(review): prmt with selector digits >= 8 replicates the selected
    // byte's sign bit — see the PTX ISA `prmt` description.
    asm volatile(
        "{ .reg .u32 tmp0, tmp1, tmp2;"
        "shl.b32 tmp0, %2, 4;"                  // move each even nibble into its byte's high-nibble slot
        "and.b32 tmp0, tmp0, 0xf0f0f0f0;"       // isolate those shifted nibbles
        "prmt.b32 tmp1, tmp0, tmp0, 0xba98;"    // replicate each byte's sign bit across the byte
        "and.b32 tmp1, tmp1, 0xf0f0f0f0;"       // keep the sign fill in the high nibble only
        "shr.u32 tmp0, tmp0, 4;"                // value nibble back to the low half (zero fill)
        "or.b32 tmp2, tmp0, tmp1;"              // tmp2 = sign-extended even elements
        "and.b32 tmp0, %2, 0xf0f0f0f0;"         // isolate the odd (high) nibbles in place
        "prmt.b32 tmp1, tmp0, tmp0, 0xba98;"    // replicate their sign bits
        "and.b32 tmp1, tmp1, 0xf0f0f0f0;"       // sign fill restricted to high nibble
        "shr.u32 tmp0, tmp0, 4;"                // value nibble to the low half
        "or.b32 tmp0, tmp0, tmp1;"              // tmp0 = sign-extended odd elements
        "prmt.b32 %0, tmp2, tmp0, 0x5140;"      // interleave: bytes for elements 0..3
        "prmt.b32 %1, tmp2, tmp0, 0x7362;"      // interleave: bytes for elements 4..7
        "}"
        : "=r"(out[0]), "=r"(out[1])
        : "r"(storage));

    return reinterpret_cast<result_type const &>(out);
  }

  CUTLASS_HOST_DEVICE
  result_type operator()(source_type const &s) const {
    return convert(s);
  }
};
|
||||
|
||||
/// Partial specialization for Array<int8_t> <= Array<int4b_t>
|
||||
template <
|
||||
int N,
|
||||
FloatRoundStyle Round
|
||||
>
|
||||
struct NumericArrayConverter<int8_t, int4b_t, N, Round> {
|
||||
static_assert(!(N % 8), "N must be multiple of 8.");
|
||||
|
||||
using result_type = Array<int8_t, N>;
|
||||
using source_type = Array<int4b_t, N>;
|
||||
static FloatRoundStyle const round_style = Round;
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
static result_type convert(source_type const & source) {
|
||||
|
||||
NumericArrayConverter<int8_t, int4b_t, 8, Round> convert_vector_;
|
||||
|
||||
result_type result;
|
||||
|
||||
Array<int8_t, 8> *result_ptr = reinterpret_cast<Array<int8_t, 8> *>(&result);
|
||||
Array<int4b_t, 8> const *source_ptr = reinterpret_cast<Array<int4b_t, 8> const *>(&source);
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < N / 8; ++i) {
|
||||
result_ptr[i] = convert_vector_(source_ptr[i]);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
result_type operator()(source_type const &s) const {
|
||||
return convert(s);
|
||||
}
|
||||
};
|
||||
|
||||
#endif // Conditional guards to enable partial specialization for packed integers
|
||||
|
||||
namespace detail {
|
||||
|
||||
Reference in New Issue
Block a user