Support for Mixed Input TensorOp (#1084)

* Passing warp-level mixed input F16*(S8/U8) tests

* passing device-level mixed input F16*(S8/U8) tests

* add to profiler - I8 (111 TFLOPs), U (123 TFLOPs)

* fast numeric conversions (I8 = 132 TFLOPs, U8 = 148 TFLOPs)

* Speedup reference compilation (REVERT THIS COMMIT)

* wider_add.u32_packed_sub.f16x2 (I8 = 132TFLOP/s, U8 = 170 TFLOP/s)

* Improve s8->f16 cvt and support bf16*u8 @158 TFLOPs

* BF16 * S8 (142 TFLOPs)

* Handle mixed-input upcast on OperandA (Support [S8|U8]*[F16|BF16]

* rename OpMultiplyAddMixedInput to OpMultiplyAddMixedInputUpcast

* Add device-level test and profiler support for upcast on operand A

* Move shfl before the cvt and reduce #shfls by 1/2

* fix smem_usage calculation for mixed_input types

* uncomment the stuff (getting ready for merge)

* profiler changes and mixed-input reference

* mixed input reference are in a new file

* use platform instead of std

* comments and typo only

* Use CreateGemmOperator and delete CreateMixedInputGemmOperator

* copyright for new files

* rebase follow-up
This commit is contained in:
Manish Gupta
2023-09-27 08:18:30 -07:00
committed by GitHub
parent 5cd735c48e
commit 7d8317a63e
26 changed files with 2064 additions and 13 deletions

View File

@@ -68,14 +68,24 @@ struct OpMultiplyAddFastF16 {};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Tag indicating the input data types are mixed and the narrower type is
/// upcasted to the wider type
struct OpMultiplyAddMixedInputUpcast {};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Tag indicating the input is converted to 2 (big and small) TF32 components
// Perform 3xTF32 or 4xTF32 for every F32 output element
struct OpMultiplyAddFastF32 {};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Tag indicating the input is converted to 2 (big and small) TF32 components
// Perform 3xTF32 or 4xTF32 for every complex<F32> output element
struct OpMultiplyAddComplexFastF32 {};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Helper for determining whether staged accumulation should be used for a given operator
template <typename Operator>
struct UseStagedAccumulation {

View File

@@ -38,6 +38,7 @@
#include "cutlass/numeric_types.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/warp/mma_tensor_op.h"
#include "cutlass/gemm/warp/mma_mixed_input_tensor_op.h"
#include "cutlass/gemm/warp/mma_tensor_op_fast_f32.h"
#include "cutlass/gemm/warp/default_mma_tensor_op.h"
@@ -227,6 +228,72 @@ struct DefaultMmaTensorOp<
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Partial Specialization - inputs are mixed types - uses wider datatype internally.
/// (e.g. F16 <= F16 x S8 + F16, F16 <= BF16 x S8 + F32)
template <
/// Shape of one matrix production operation (concept: GemmShape)
typename WarpShape_,
/// Element type of A matrix
typename ElementA,
/// Layout of A matrix (concept: MatrixLayout)
typename LayoutA,
/// Element type of B matrix
typename ElementB,
/// Layout of B matrix (concept: MatrixLayout)
typename LayoutB,
/// Element type of C matrix
typename ElementC,
/// Layout of C matrix (concept: MatrixLayout)
typename LayoutC,
/// Number of partitions along K dimension
int PartitionsK,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.
bool AccumulatorsInRowMajor>
struct DefaultMmaTensorOp<
WarpShape_,
GemmShape<16, 8, 16>, // InstructionShape
ElementA, // Element type of A matrix in Global Memory
LayoutA, // Layout of A matrix in Global Memory
ElementB, // Element type of B matrix in Global Memory
LayoutB, // Layout of B matrix in Global Memory
ElementC, // Element type of C matrix in Global Memory
LayoutC, // Layout of C matrix in Global Memory
arch::OpMultiplyAddMixedInputUpcast, // Tag to indicate mixed-input datatype, where narrower datatype is upcasted to wider datatype
PartitionsK, AccumulatorsInRowMajor> {
// Check if the ElementA and ElementB are of different data types
static_assert(!platform::is_same<ElementA, ElementB>::value,
"DefaultMmaTensorOp with arch::OpMultiplyAddMixedInputUpcast ElementA and ElementB cannot be of the same data type");
// Data type used for internal computation - use the wider of the two data types for mma.sync operands
using ElementOperand = typename platform::conditional<(sizeof(ElementA) > sizeof(ElementB)),
ElementA, ElementB>::type;
// Operand datatypes in the internal MMA instruction - use the wider of the two data types
using MmaElementA = ElementOperand;
using MmaElementB = ElementOperand;
using MmaElementC = ElementC;
// Uses
using Policy = cutlass::gemm::warp::MmaTensorOpPolicy<
cutlass::arch::Mma<
GemmShape<16, 8, 16>,
32,
MmaElementA, cutlass::layout::RowMajor,
MmaElementB, cutlass::layout::ColumnMajor,
MmaElementC, cutlass::layout::RowMajor,
arch::OpMultiplyAdd
>,
cutlass::MatrixShape<1, 1> >;
// Define the warp-level tensor op
using Type = cutlass::gemm::warp::MmaMixedInputTensorOp<
WarpShape_, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
Policy, PartitionsK, AccumulatorsInRowMajor>;
};
} // namespace warp
} // namespace gemm
} // namespace cutlass

View File

@@ -0,0 +1,554 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Templates implementing warp-level matrix multiply-accumulate operations targeting
Tensor Cores.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/platform/platform.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/mma_sm80.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/warp/mma.h"
#include "cutlass/gemm/warp/mma_tensor_op_policy.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator.h"
#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace warp {
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
////////////////////////////////////////////////////////////////////////////////
// Shuffle registers for layout conversion
////////////////////////////////////////////////////////////////////////////////
template <
/// Element type for the operand in registers for the mma.sync
typename ElementMma_,
/// Element type for the operand in shared memory for ldmatrix
typename ElementLoad_,
/// Number of mma.sync operations performed along rows or columns
int NumMmaInstructions,
/// Number of elements in warp fragment
int NumElementsInWarpFragment,
/// Number of elements in mma fragment
int NumElementsInMmaFragment,
/// Identifies A or B multiplicand
Operand Operand_,
///
typename Enable = void >
struct FragmentShuffler {
public:
using ElementMma = ElementMma_;
using ElementLoad = ElementLoad_;
static int const kNumMmaInstructions = NumMmaInstructions;
static int const kNumElementsInWarpFragment = NumElementsInWarpFragment;
static int const kNumElementsInMmaFragment = NumElementsInMmaFragment;
static Operand const kOperand = Operand_;
using WarpFragment = Array<ElementLoad, kNumElementsInWarpFragment>;
using MmaFragment = Array<ElementLoad, kNumElementsInMmaFragment>;
CUTLASS_DEVICE
WarpFragment operator()(WarpFragment const &src) {
return src;
}
};
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 8b (S8/U8)
/// for operand A multiplicand going through upcasting.
template <
/// Element type for the operand in registers for the mma.sync
typename ElementMma_,
/// Element type for the operand in shared memory for ldmatrix
typename ElementLoad_,
/// Number of mma.sync operations performed along rows or columns
int NumMmaInstructions,
/// Number of elements in warp fragment
int NumElementsInWarpFragment,
/// Number of elements in mma fragment
int NumElementsInMmaFragment
>
struct FragmentShuffler <ElementMma_, ElementLoad_,
NumMmaInstructions,
NumElementsInWarpFragment,
NumElementsInMmaFragment,
Operand::kA,
typename platform::enable_if<(sizeof_bits<ElementMma_>::value == 16) &&
(sizeof_bits<ElementLoad_>::value == 8)>::type> {
public:
using ElementMma = ElementMma_;
using ElementLoad = ElementLoad_;
static int const kNumMmaInstructions = NumMmaInstructions;
static int const kNumElementsInWarpFragment = NumElementsInWarpFragment;
static int const kNumElementsInMmaFragment = NumElementsInMmaFragment;
static Operand const kOperand = Operand::kA;
using WarpFragment = Array<ElementLoad, kNumElementsInWarpFragment>;
using MmaFragment = Array<ElementLoad, kNumElementsInMmaFragment>;
static uint32_t const kSelectBytesEvenThread = 0x5410;
static uint32_t const kSelectBytesOddThread = 0x7632;
private:
int delta_up_;
int delta_down_;
int odd_even_lane_id_;
uint32_t byte_selector_;
public:
CUTLASS_DEVICE
FragmentShuffler() {
int lane_id = cutlass::arch::LaneId();
delta_up_ = (lane_id & 1) + ((lane_id & 2) >> 1);
delta_down_ = 2 - delta_up_;
odd_even_lane_id_ = static_cast<int>(lane_id & 1);
byte_selector_ = odd_even_lane_id_ * kSelectBytesOddThread +
(1 - odd_even_lane_id_) * kSelectBytesEvenThread;
}
CUTLASS_DEVICE
WarpFragment operator()(WarpFragment const &src) {
WarpFragment result;
MmaFragment const* mma_frag_src_ptr = reinterpret_cast<MmaFragment const*>(&src);
MmaFragment* mma_frag_dst_ptr = reinterpret_cast<MmaFragment*>(&result);
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kNumMmaInstructions; n++) {
uint32_t const* src_ptr = reinterpret_cast<uint32_t const *>(&mma_frag_src_ptr[n]);
uint32_t *dst_ptr = reinterpret_cast<uint32_t *>(&mma_frag_dst_ptr[n]);
// Shuffle data within the warp, pull from other threads within the warp
uint32_t tmp0 = __shfl_up_sync(0xFFFFFFFF, src_ptr[0], delta_up_);
uint32_t tmp1 = __shfl_down_sync(0xFFFFFFFF, src_ptr[0], delta_down_);
uint32_t tmp2 = __shfl_up_sync(0xFFFFFFFF, src_ptr[1], delta_up_);
uint32_t tmp3 = __shfl_down_sync(0xFFFFFFFF, src_ptr[1], delta_down_);
// Reorder the data within the 32-bit word (4x8b) required for mma.sync
dst_ptr[0] = __byte_perm(tmp0, tmp2, byte_selector_);
dst_ptr[1] = __byte_perm(tmp1, tmp3, byte_selector_);
}
return result;
}
};
////////////////////////////////////////////////////////////////////////////////
/// Partial specialization for `mma.sync` on 16b (F16/BF16) and `ldmatrix` on 8b (S8/U8)
/// for operand B multiplicand going through upcasting.
template <
/// Element type for the operand in registers for the mma.sync
typename ElementMma_,
/// Element type for the operand in shared memory for ldmatrix
typename ElementLoad_,
/// Number of mma.sync operations performed along rows or columns
int NumMmaInstructions,
/// Number of elements in warp fragment
int NumElementsInWarpFragment,
/// Number of elements in mma fragment
int NumElementsInMmaFragment
>
struct FragmentShuffler <ElementMma_, ElementLoad_,
NumMmaInstructions,
NumElementsInWarpFragment,
NumElementsInMmaFragment,
Operand::kB,
typename platform::enable_if<(sizeof_bits<ElementMma_>::value == 16) &&
(sizeof_bits<ElementLoad_>::value == 8)>::type> {
public:
using ElementMma = ElementMma_;
using ElementLoad = ElementLoad_;
static int const kNumMmaInstructions = NumMmaInstructions;
static int const kNumElementsInWarpFragment = NumElementsInWarpFragment;
static int const kNumElementsInMmaFragment = NumElementsInMmaFragment;
static Operand const kOperand = Operand::kB;
using WarpFragment = Array<ElementLoad, kNumElementsInWarpFragment>;
using MmaFragment = Array<ElementLoad, kNumElementsInMmaFragment>;
static uint32_t const kSelectBytesEvenThread = 0x5410;
static uint32_t const kSelectBytesOddThread = 0x7632;
private:
int delta_up_;
int delta_down_;
int odd_even_lane_id_;
uint32_t byte_selector_;
public:
CUTLASS_DEVICE
FragmentShuffler() {
int lane_id = cutlass::arch::LaneId();
delta_up_ = (lane_id & 1) + ((lane_id & 2) >> 1);
delta_down_ = 2 - delta_up_;
odd_even_lane_id_ = static_cast<int>(lane_id & 1);
byte_selector_ = odd_even_lane_id_ * kSelectBytesOddThread +
(1 - odd_even_lane_id_) * kSelectBytesEvenThread;
}
CUTLASS_DEVICE
WarpFragment operator()(WarpFragment const &src) {
WarpFragment result;
MmaFragment const* mma_frag_src_ptr = reinterpret_cast<MmaFragment const *>(&src);
MmaFragment* mma_frag_dst_ptr = reinterpret_cast<MmaFragment *>(&result);
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < kNumMmaInstructions; n++) {
uint32_t const* src_ptr = reinterpret_cast<uint32_t const*>(&mma_frag_src_ptr[n]);
uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&mma_frag_dst_ptr[n]);
// Shuffle data within the warp, pull from other threads within the warp
uint32_t tmp0 = __shfl_up_sync(0xFFFFFFFF, src_ptr[0], delta_up_);
uint32_t tmp1 = __shfl_down_sync(0xFFFFFFFF, src_ptr[0], delta_down_);
// Reorder the data within the 32-bit word (4x8b) required for mma.sync
dst_ptr[0] = __byte_perm(tmp0, tmp1, byte_selector_);
}
return result;
}
};
////////////////////////////////////////////////////////////////////////////////
// Data type conversion
////////////////////////////////////////////////////////////////////////////////
template <
/// Destination type
typename ElementDst_,
/// Source type
typename ElementSrc_,
/// Number of elements
int N,
///
typename Enable = void>
struct FragmentConverter {
using ElementDst = ElementDst_;
using ElementSrc = ElementSrc_;
// Operand fragment registers in destination and source types
using DestinationFragment = Array<ElementDst, N>;
using SourceFragment = Array<ElementSrc, N>;
FastNumericArrayConverter<ElementDst, ElementSrc, N> convert;
CUTLASS_DEVICE
DestinationFragment operator()(SourceFragment const &src) const {
return convert(src);
}
};
////////////////////////////////////////////////////////////////////////////////
// Partial specialization for when Destination type is the *same* as
// Source type
template <
/// Data type
typename Element,
/// Number of elements
int N,
///
typename Enable>
struct FragmentConverter<Element, Element, N, Enable> {
using DestinationFragment = Array<Element, N>;
using SourceFragment = Array<Element, N>;
CUTLASS_DEVICE
DestinationFragment operator()(SourceFragment const &src) const {
return src;
}
};
} // namespace detail
/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions.
template <
/// Size of the Gemm problem - concept: gemm::GemmShape<>
typename Shape_,
/// Data type of A elements
typename ElementA_,
/// Layout of A matrix (concept: MatrixLayout)
typename LayoutA_,
/// Data type of B elements
typename ElementB_,
/// Layout of B matrix (concept: MatrixLayout)
typename LayoutB_,
/// Element type of C matrix
typename ElementC_,
/// Layout of C matrix (concept: MatrixLayout)
typename LayoutC_,
/// Policy describing warp-level MmaTensorOp (concept: MmaTensorOp policy)
typename Policy_,
/// Number of partitions along K dimension
int PartitionsK_ = 1,
/// Store the accumulators in row major or column major. Row major is used
/// when output layout is interleaved.
bool AccumulatorsInRowMajor = false,
/// Used for partial specialization
typename Enable = bool
>
class MmaMixedInputTensorOp {
public:
/// Shape of warp-level matrix operation (concept: GemmShape)
using Shape = Shape_;
/// Data type of multiplicand A
using ElementA = ElementA_;
/// Layout of multiplicand A
using LayoutA = LayoutA_;
/// Data type of multiplicand B
using ElementB = ElementB_;
/// Layout of multiplicand B
using LayoutB = LayoutB_;
/// Data type of accumulator matrix C
using ElementC = ElementC_;
/// Layout of accumulator matrix C
using LayoutC = LayoutC_;
/// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
using Policy = Policy_;
/// Underlying matrix multiply operator (concept: arch::Mma)
using ArchMmaOperator = typename Policy::Operator;
/// Underlying arch::Mma instruction datatype for A operand
using MmaElementA = typename ArchMmaOperator::ElementA;
/// Underlying arch::Mma instruction datatype for B operand
using MmaElementB = typename ArchMmaOperator::ElementB;
/// Underlying arch::Mma instruction datatype for C operand
using MmaElementC = typename ArchMmaOperator::ElementC;
/// Indicates math operator
using MathOperator = typename ArchMmaOperator::Operator;
/// Architecture tag from underlying instruction
using ArchTag = typename ArchMmaOperator::ArchTag;
/// Indicates class of matrix operator
using OperatorClass = arch::OpClassTensorOp;
/// Shape of underlying instruction
using InstructionShape = typename ArchMmaOperator::Shape;
/// Complex transform on A operand
static ComplexTransform const kTransformA = ComplexTransform::kNone;
/// Complex transform on B operand
static ComplexTransform const kTransformB = ComplexTransform::kNone;
/// Number of threads participating in warp-level matrix product
static int const kThreadCount = 32;
/// Number of partitions along K dimension
static int const kPartitionsK = PartitionsK_;
///
// static int const kLoadShapeK = InstructionShape::kK *
// (sizeof_bits<MmaElementA>::value / sizeof_bits<ElementB>::value);
public:
/// Iterates over the A operand in Shared Memory
using IteratorA = MmaTensorOpMultiplicandTileIterator<
MatrixShape<Shape::kM, Shape::kK>, Operand::kA, ElementA, LayoutA,
MatrixShape<ArchMmaOperator::Shape::kM, ArchMmaOperator::Shape::kK>,
Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;
/// Storage for A tile in registers (loaded from Shared Memory)
using FragmentA = typename IteratorA::Fragment;
/// Storage for transformed A tile in registers (for use in Mma instruction)
using TransformedFragmentA =
Array<MmaElementA, FragmentA::kElements>;
/// Underlying arch::Mma instruction operand fragement for matrix A
using MmaOperandA = typename ArchMmaOperator::FragmentA;
/// Iterates over the B operand in Shared Memory
using IteratorB = MmaTensorOpMultiplicandTileIterator<
MatrixShape<Shape::kK, Shape::kN>, Operand::kB, ElementB, LayoutB,
MatrixShape<ArchMmaOperator::Shape::kK, ArchMmaOperator::Shape::kN>,
Policy::OpDelta::kRow, kThreadCount, kPartitionsK>;
/// Storage for B tile in registers (loaded from Shared Memory)
using FragmentB = typename IteratorB::Fragment;
/// Storage for transformed B tile in registers (for use in Mma instruction)
using TransformedFragmentB =
Array<MmaElementB, FragmentB::kElements>;
/// Underlying arch::Mma instruction operand fragement for matrix B
using MmaOperandB = typename ArchMmaOperator::FragmentB;
/// Iterates over the C operand in memory
using IteratorC = MmaTensorOpAccumulatorTileIterator<
MatrixShape<Shape::kM, Shape::kN>, ElementC, LayoutC,
typename ArchMmaOperator::Shape, typename Policy::OpDelta>;
/// Storage for C tile
using FragmentC = typename IteratorC::Fragment;
/// Underlying arch::Mma instruction operand fragement for matrix C
using MmaOperandC = typename ArchMmaOperator::FragmentC;
/// Number of mma operations performed
using MmaIterations = MatrixShape<
(Shape::kM + ArchMmaOperator::Shape::kM - 1) / ArchMmaOperator::Shape::kM,
(Shape::kN + ArchMmaOperator::Shape::kN - 1) / ArchMmaOperator::Shape::kN
>;
public:
/// Underlying matrix multiply operator (concept: arch::Mma)
ArchMmaOperator mma;
public:
//
// Methods
//
/// Ctor
CUTLASS_DEVICE
MmaMixedInputTensorOp() {}
/// Performs a warp-level matrix multiply-accumulate operation
CUTLASS_DEVICE
void operator()(
FragmentC &D,
TransformedFragmentA const &A,
TransformedFragmentB const &B,
FragmentC const &C
) const {
D = C;
MmaOperandA const *ptr_A = reinterpret_cast<MmaOperandA const *>(&A);
MmaOperandB const *ptr_B = reinterpret_cast<MmaOperandB const *>(&B);
MmaOperandC *ptr_D = reinterpret_cast<MmaOperandC *>(&D);
CUTLASS_PRAGMA_UNROLL
for (int m = 0; m < MmaIterations::kRow; ++m) {
CUTLASS_PRAGMA_UNROLL
for (int n = 0; n < MmaIterations::kColumn; ++n) {
int n_serpentine = ((m % 2) ? (MmaIterations::kColumn - 1 - n) : n);
if (AccumulatorsInRowMajor) { // matrix B is reordered
mma(
ptr_D[n_serpentine + m * MmaIterations::kColumn],
ptr_A[m],
ptr_B[n_serpentine],
ptr_D[n_serpentine + m * MmaIterations::kColumn]);
} else {
mma(ptr_D[m + n_serpentine * MmaIterations::kRow],
ptr_A[m],
ptr_B[n_serpentine],
ptr_D[m + n_serpentine * MmaIterations::kRow]);
}
}
}
}
/// Transform the operand warp fragment register to the required data types and layout
/// for the `cultass::arch::Mma`
CUTLASS_DEVICE
void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B,
FragmentA const &A, FragmentB const &B) const {
// Shuffle data within warp to obtain the mma.sync operand layout
detail::FragmentShuffler<MmaElementA, ElementA, MmaIterations::kRow,
FragmentA::kElements, MmaOperandA::kElements, Operand::kA> shuffler_A;
FragmentA tmp_A;
tmp_A = shuffler_A(A);
// Convert the A operand to the Mma Instruction operand type
detail::FragmentConverter<MmaElementA, ElementA, FragmentA::kElements> convert_A;
dst_A = convert_A(tmp_A);
// Shuffle data within warp to obtain the mma.sync operand layout
detail::FragmentShuffler<MmaElementB, ElementB, MmaIterations::kColumn,
FragmentB::kElements, MmaOperandB::kElements, Operand::kB> shuffler_B;
FragmentB tmp_B;
tmp_B = shuffler_B(B);
// Convert the B operand to the Mma Instruction operand type
detail::FragmentConverter<MmaElementB, ElementB, FragmentB::kElements> convert_B;
dst_B = convert_B(tmp_B);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace warp
} // namespace gemm
} // namespace cutlass
/////////////////////////////////////////////////////////////////////////////////////////////////

View File

@@ -2340,7 +2340,8 @@ struct NumericArrayConverter<uint4b_t, int, N, Round> {
/// Conversion operator for Array. See the comments before
/// FastLinearCombinationClamp.
template <typename T, typename S, int N,
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest>
FloatRoundStyle Round = FloatRoundStyle::round_to_nearest,
typename Enable = void>
struct FastNumericArrayConverter {
using result_type = Array<T, N>;
using source_type = Array<S, N>;
@@ -2441,6 +2442,225 @@ struct FastNumericArrayConverter<int8_t, float, N, Round> {
result_type operator()(source_type const &s) const { return convert(s); }
};
/// Partial specialization for Array<cutlass::half_t, 4> <= Array<int8_t, 4>
template <FloatRoundStyle Round>
struct FastNumericArrayConverter<cutlass::half_t, int8_t, 4, Round> {
using result_type = Array<cutlass::half_t, 4>;
using source_type = Array<int8_t, 4>;
static FloatRoundStyle const round_style = Round;
CUTLASS_DEVICE
static result_type convert(source_type const &source) {
result_type result;
#if 0 // Scalar conversion (Please keep this code for reference for vectorized version below)
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < 4; ++i) {
int16_t tmp = source[i] + 26112 /* 0x6600 */;
result[i] = reinterpret_cast<cutlass::half_t const &>(tmp) - 1536.0_hf;
}
#endif
// Vectorized s8->f16 conversion using packed instructions
uint32_t const* source_ptr = reinterpret_cast<uint32_t const*>(&source);
uint32_t* result_ptr = reinterpret_cast<uint32_t*>(&result);
// Pack s8x2 (s8[1], s8[0]) -> s16x2 (sext.s8[1], sext.s8[0])
// (See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt)
// The inline ptx below uses `msb=0` and `msb=1` from the above link to sign extend the sign-bit in 0, 1, 2, 3 bytes of s8x4
// into result_ptr[0] and result_ptr[1]'s 08-15 and 24-31 bits, respectively.
// Note that `__byte_perm(source_ptr[0], source_ptr[0], 0x9180);` won't achieve the same and doesn't sign extend the sign-bit.
// Thus, we use inline ptx `prmt.b32` instruction for the desired sign extend from `s8x2` to `s16x2`.
asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(result_ptr[0]) : "r"(source_ptr[0]), "n"(0x9180));
asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(result_ptr[1]) : "r"(source_ptr[0]), "n"(0xB3A2));
// In the absense of add.s16x2 instruction, use bit-wise operation to execute signed addition with magic numbers to achieve
// the same result as add.s16x2 instruction.
// (See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-lop3)
// For a logical operation F(a, b, c) the value of kImmLut can be computed by applying the same operation to
// three predefined constant values as follows:
// ta = 0xF0;
// tb = 0xCC;
// tc = 0xAA;
// kImmLut = F(ta, tb, tc);
// If we want F = ((a & b) ^ c) then set kImmLut = (0xF0 & 0xCC) ^ 0xAA
static constexpr uint32_t kImmLut = (0xF0 & 0xCC) ^ 0xAA;
// The bit-wise operation executed below is `result_ptr[0] = (result_ptr[0] & 0x03FF03FF) ^ 0x66006600;`
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" :
"=r"(result_ptr[0]) : "r"(result_ptr[0]), "n"(0x03FF03FF), "n"(0x66006600), "n"(kImmLut));
// The bit-wise operation executed below is `result_ptr[1] = (result_ptr[1] & 0x03FF03FF) ^ 0x66006600;`
asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" :
"=r"(result_ptr[1]) : "r"(result_ptr[1]), "n"(0x03FF03FF), "n"(0x66006600), "n"(kImmLut));
// Packed sub.f16x2 with magic number to obtain final converted result
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(result_ptr[0]) : "r"(result_ptr[0]), "r"(0x66006600));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(result_ptr[1]) : "r"(result_ptr[1]), "r"(0x66006600));
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const &s) const {
return convert(s);
}
};
/// Partial specialization for Array<cutlass::half_t, 4> <= Array<uint8_t, 4>
template <FloatRoundStyle Round>
struct FastNumericArrayConverter<cutlass::half_t, uint8_t, 4, Round> {
using result_type = Array<cutlass::half_t, 4>;
using source_type = Array<uint8_t, 4>;
static FloatRoundStyle const round_style = Round;
CUTLASS_DEVICE
static result_type convert(source_type const &source) {
result_type result;
uint32_t const* source_ptr = reinterpret_cast<uint32_t const*>(&source);
uint32_t* result_ptr = reinterpret_cast<uint32_t*>(&result);
result_ptr[0] = __byte_perm(source_ptr[0], 0x0, 0x4140);
result_ptr[1] = __byte_perm(source_ptr[0], 0x0, 0x4342);
asm volatile("add.u32 %0, %1, %2;\n" : "=r"(result_ptr[0]) : "r"(result_ptr[0]), "r"(0x66006600));
asm volatile("add.u32 %0, %1, %2;\n" : "=r"(result_ptr[1]) : "r"(result_ptr[1]), "r"(0x66006600));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(result_ptr[0]) : "r"(result_ptr[0]), "r"(0x66006600));
asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(result_ptr[1]) : "r"(result_ptr[1]), "r"(0x66006600));
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const &s) const {
return convert(s);
}
};
/// Partial specialization for Array<cutlass::bfloat16_t, 4> <= Array<uint8_t, 4>
template <FloatRoundStyle Round>
struct FastNumericArrayConverter<cutlass::bfloat16_t, uint8_t, 4, Round> {
using result_type = Array<cutlass::bfloat16_t, 4>;
using source_type = Array<uint8_t, 4>;
static FloatRoundStyle const round_style = Round;
CUTLASS_DEVICE
static result_type convert(source_type const &source) {
result_type result;
Array<float, 4> tmp;
uint32_t const* source_ptr = reinterpret_cast<uint32_t const*>(&source);
uint32_t* tmp_ptr = reinterpret_cast<uint32_t*>(&tmp);
// __byte_perm simulates the add.u32 0x4B000000 to every u8 element of u8x4 source and stores
// the result in tmp (without introducing extra cvt.u32.u8 instruction)
tmp_ptr[0] = __byte_perm(source_ptr[0], 0x4B000000, 0x7650);
tmp_ptr[1] = __byte_perm(source_ptr[0], 0x4B000000, 0x7651);
tmp_ptr[2] = __byte_perm(source_ptr[0], 0x4B000000, 0x7652);
tmp_ptr[3] = __byte_perm(source_ptr[0], 0x4B000000, 0x7653);
// Subtract the magic number 0x4B000000 from tmp in floating-point arithmetic to obtain final result
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < 4; ++i) {
tmp[i] = reinterpret_cast<float const &>(tmp_ptr[i]) - 8388608.f;
}
// on 3456x4096x8192 runs at 158 TFLOP/s
// Convert f32x2 to bf16x2 using `cvt.rn.b16x2.f32` instruction
NumericArrayConverter<cutlass::bfloat16_t, float, 4, Round> convert_f32_to_bf16;
result = convert_f32_to_bf16(tmp);
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const &s) const {
return convert(s);
}
};
/// Partial specialization for Array<cutlass::bfloat16_t, 4> <= Array<int8_t, 4>
template <FloatRoundStyle Round>
struct FastNumericArrayConverter<cutlass::bfloat16_t, int8_t, 4, Round> {
using result_type = Array<cutlass::bfloat16_t, 4>;
using source_type = Array<int8_t, 4>;
using intermediate_float_type = Array<float, 4>;
using intermediate_int32_type = Array<int32_t, 4>;
static FloatRoundStyle const round_style = Round;
CUTLASS_DEVICE
static result_type convert(source_type const &source) {
result_type result;
intermediate_float_type tmp;
uint32_t const* source_ptr = reinterpret_cast<uint32_t const*>(&source);
uint32_t* tmp_ptr = reinterpret_cast<uint32_t*>(&tmp);
// s8x4 (s[3], s[2], s8[1], s8[0]) -> s16x4 (sext.s8[3], sext.s8[2], sext.s8[1], sext.s8[0])
// (See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-prmt)
// The inline ptx below uses `msb=0` and `msb=1` from the above link to sext the sign-bit in 0, 1, 2, 3 bytes of s8x4
// sext without unpacking each s8 out of s8x4 into a separate register a.ka. without using shifts (SHFL).
asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(tmp_ptr[0]) : "r"(source_ptr[0]), "n"(0x8880));
asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(tmp_ptr[1]) : "r"(source_ptr[0]), "n"(0x9991));
asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(tmp_ptr[2]) : "r"(source_ptr[0]), "n"(0xAAA2));
asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(tmp_ptr[3]) : "r"(source_ptr[0]), "n"(0xBBB3));
// Convert s32x4 to f32x4 using fast numeric array converter
FastNumericArrayConverter<float, int32_t, 4, Round> convert_s32_to_f32_;
tmp = convert_s32_to_f32_(reinterpret_cast<intermediate_int32_type const &>(tmp[0]));
// Convert f32x2 to bf16x2 using `cvt.rn.b16x2.f32` instruction
NumericArrayConverter<cutlass::bfloat16_t, float, 4, Round> convert_f32_to_bf16_;
result = convert_f32_to_bf16_(tmp);
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const &s) const {
return convert(s);
}
};
/// Partial specialization for FastNumericArrayConverter to vectorize over 4 elements.
/// source `S` as 8b integers (S8 or U8) -> destination `T` as 16b floating-point (F16 or BF16)
template <typename T, typename S, int N, FloatRoundStyle Round>
struct FastNumericArrayConverter<T, S, N, Round,
typename platform::enable_if<(platform::is_same<T, half_t>::value || platform::is_same<T, bfloat16_t>::value) &&
(platform::is_same<S, int8_t>::value || platform::is_same<S, uint8_t>::value)>::type> {
static_assert(!(N % 4), "N must be multiple of 4.");
using result_type = Array<T, N>;
using source_type = Array<S, N>;
static FloatRoundStyle const round_style = Round;
CUTLASS_DEVICE
static result_type convert(source_type const &source) {
FastNumericArrayConverter<T, S, 4, Round> convert_vector_;
result_type result;
Array<T, 4> *result_ptr =
reinterpret_cast<Array<T, 4> *>(&result);
Array<S, 4> const *source_ptr =
reinterpret_cast<Array<S, 4> const *>(&source);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < N / 4; ++i) {
result_ptr[i] = convert_vector_(source_ptr[i]);
}
return result;
}
CUTLASS_DEVICE
result_type operator()(source_type const &s) const { return convert(s); }
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/// Defines preferred rounding mode for a pair of types