mirror of
https://github.com/NVIDIA/cutlass.git
synced 2026-05-12 09:15:56 +00:00
v4.1 release update v2. (#2481)
This commit is contained in:
@@ -177,7 +177,7 @@ struct WmmaToCutlassDataType<__nv_bfloat16> {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// WMMA template structure defines nvcuda::wmma::fragments and static assertion chaeks
|
||||
// for a specific template paramterized data type (Element[A|B|C]), layout (Layout[A|B|C]),
|
||||
// for a specific template parameterized data type (Element[A|B|C]), layout (Layout[A|B|C]),
|
||||
// and native wmma size (Shape)
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
template <
|
||||
|
||||
@@ -123,7 +123,7 @@ struct Wmma<
|
||||
nvcuda::wmma::mma_sync(D, A, B, C);
|
||||
}
|
||||
#else
|
||||
static_assert(false, "wmma.mma.sync for floating point multiplicands is avialable only for SM70 and beyond");
|
||||
static_assert(false, "wmma.mma.sync for floating point multiplicands is available only for SM70 and beyond");
|
||||
#endif
|
||||
|
||||
};
|
||||
|
||||
@@ -117,7 +117,7 @@ struct Wmma<
|
||||
}
|
||||
|
||||
#else
|
||||
static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM72 and beyond");
|
||||
static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM72 and beyond");
|
||||
#endif
|
||||
|
||||
};
|
||||
@@ -197,7 +197,7 @@ struct Wmma<
|
||||
}
|
||||
|
||||
#else
|
||||
static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM72 and beyond");
|
||||
static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM72 and beyond");
|
||||
#endif
|
||||
|
||||
};
|
||||
|
||||
@@ -115,7 +115,7 @@ struct Wmma<
|
||||
}
|
||||
|
||||
#else
|
||||
static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM75 and beyond");
|
||||
static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM75 and beyond");
|
||||
#endif
|
||||
|
||||
};
|
||||
@@ -194,7 +194,7 @@ struct Wmma<
|
||||
}
|
||||
|
||||
#else
|
||||
static_assert(false, "wmma.mma.sync interger type multiplicands is avialable only for SM75 and beyond");
|
||||
static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM75 and beyond");
|
||||
#endif
|
||||
|
||||
};
|
||||
|
||||
@@ -118,7 +118,7 @@ struct Array<T, N, false> {
|
||||
// result[0] = xxx;
|
||||
// ```
|
||||
//
|
||||
// Will leads to compiler warning on use of unintialized member variable. Although we know
|
||||
// Will leads to compiler warning on use of uninitialized member variable. Although we know
|
||||
// this read of uninitialized member variable is harmeless.
|
||||
|
||||
#if defined(__clang__)
|
||||
|
||||
@@ -1056,7 +1056,7 @@ struct DefaultConv2dFprop <
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
|
||||
/// multistage pipeline.
|
||||
template <
|
||||
typename ElementA,
|
||||
@@ -1184,7 +1184,7 @@ struct DefaultConv2dFprop <
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
|
||||
// multistage pipeline with interleaved layout.
|
||||
template <
|
||||
typename ElementA,
|
||||
|
||||
@@ -215,7 +215,7 @@ struct DefaultConv2dFpropFusion <
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimzed IteratorAlgorithm and
|
||||
/// Defines a kernel for Conv2dFprop specialization for Optimized IteratorAlgorithm and
|
||||
/// multistage pipeline.
|
||||
template <
|
||||
typename ElementA,
|
||||
|
||||
@@ -217,7 +217,7 @@ struct DefaultConv3dFpropFusion <
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Defines a kernel for Conv3dFprop specialzation for Optimzed IteratorAlgorithm and
|
||||
/// Defines a kernel for Conv3dFprop specialzation for Optimized IteratorAlgorithm and
|
||||
/// multistage pipeline.
|
||||
template <
|
||||
typename ElementA,
|
||||
|
||||
@@ -30,7 +30,7 @@
|
||||
**************************************************************************************************/
|
||||
|
||||
/*! \file
|
||||
\brief Interface betweeen a CUTLASS device-wide operator and CUDA.
|
||||
\brief Interface between a CUTLASS device-wide operator and CUDA.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
@@ -392,7 +392,7 @@ protected:
|
||||
|
||||
/**
|
||||
* Fills a buffer in Global Memory with a byte sequence copied from host memory.
|
||||
* This function can be overriden to dispatch to the appropriate cuMemsetD*Async API
|
||||
* This function can be overridden to dispatch to the appropriate cuMemsetD*Async API
|
||||
*/
|
||||
virtual Status memsetDeviceImpl(
|
||||
void* destination, ///< Device memory pointer to be filled
|
||||
|
||||
@@ -271,7 +271,7 @@ struct CollectiveBuilder<
|
||||
|
||||
// Construct TileShape for SFB load from GMEM to SMEM.
|
||||
// It is required to keep consistency with BlockScaled granularity defined in Sm1xxBlkScaledConfig.
|
||||
// So that TileShape for scaling factor needs to be defined as a mutliple of Blk_MN.
|
||||
// So that TileShape for scaling factor needs to be defined as a multiple of Blk_MN.
|
||||
using TileShapeSf_MNK = decltype(make_shape(ceil_div(size<0>(TileShape_MNK{}), Blk_MN{}) * Blk_MN{},
|
||||
ceil_div(size<1>(TileShape_MNK{}), Blk_MN{}) * Blk_MN{},
|
||||
shape<2>(TileShape_MNK{})));
|
||||
|
||||
@@ -153,13 +153,13 @@ struct CollectiveMma<
|
||||
// Asymmetric buffering
|
||||
// Tensor A/B could have different buffering, with TILEK, and STAGEs.
|
||||
// It let AsymmetricKRatio equals TILEK_A / TILEK_B, to make sure A/B's
|
||||
// pipeline keep same steps when procude / consume data.
|
||||
// pipeline keep same steps when produce / consume data.
|
||||
// Currently, AsymmetricKRatio = {1, 2} is the only support.
|
||||
static constexpr int AsymmetricKRatio = DispatchPolicy::StagesA != DispatchPolicy::StagesB ? 2 : 1;
|
||||
|
||||
// Construct TileShape for SFB load from GMEM to SMEM.
|
||||
// It is required to keep consistency with BlockScaled granularity defined in Sm1xxBlkScaledConfig.
|
||||
// So that TileShape for scaling factor needs to be defined as a mutliple of Blk_MN.
|
||||
// So that TileShape for scaling factor needs to be defined as a multiple of Blk_MN.
|
||||
using Blk_MN = typename Sm1xxBlkScaledConfig::Blk_MN;
|
||||
using TileShapeSF = decltype(make_shape(ceil_div(size<0>(CtaShape_MNK{}), Blk_MN{}) * Blk_MN{},
|
||||
ceil_div(size<1>(CtaShape_MNK{}), Blk_MN{}) * Blk_MN{},
|
||||
|
||||
@@ -136,7 +136,7 @@ struct CollectiveMma<
|
||||
// Asymmetric buffering
|
||||
// Tensor A/B could have different buffering, with TILEK, and STAGEs.
|
||||
// It let AsymmetricKRatio equals TILEK_A / TILEK_B, to make sure A/B's
|
||||
// pipeline keep same steps when procude / consume data.
|
||||
// pipeline keep same steps when produce / consume data.
|
||||
static constexpr int AsymmetricKRatio = DispatchPolicy::StagesA != DispatchPolicy::StagesB ? 2 : 1;
|
||||
|
||||
using TileShapeB = decltype(make_shape(size<0>(TileShape{}),
|
||||
|
||||
@@ -100,7 +100,7 @@ struct CollectiveMma<
|
||||
using TransformB = TransformB_;
|
||||
using ArchTag = typename DispatchPolicy::ArchTag;
|
||||
// Follow the change in TestSmall: TileShape switch to CtaShape
|
||||
// For sm80 arch, CtaShape should euqal to TileShape
|
||||
// For sm80 arch, CtaShape should equal to TileShape
|
||||
using CtaShape_MNK = TileShape;
|
||||
|
||||
static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
|
||||
|
||||
@@ -99,7 +99,7 @@ namespace device {
|
||||
|
||||
Description of parameters and tensors used to represent the Blocked-Ellpack (ELL) format:
|
||||
a_rows - Rows in the sparse matrix.
|
||||
a_cols - Colums in the sparse matrix.
|
||||
a_cols - Columns in the sparse matrix.
|
||||
BlockedEllA - Packed matrix (ellValue matrix) that stores non-zero values in
|
||||
consecutive blocks, whose size is (a_rows * a_ell_num_columns)
|
||||
ell_idx - Blocked-ELL Column indices (ellColInd) matrix, whose size is
|
||||
@@ -715,7 +715,7 @@ public:
|
||||
/// Constructs the GEMM.
|
||||
EllGemm() { }
|
||||
|
||||
/// Helper to construct a transposed equivalent for the underying GEMM operator
|
||||
/// Helper to construct a transposed equivalent for the underlying GEMM operator
|
||||
static UnderlyingArguments to_underlying_arguments(Arguments const &args) {
|
||||
return UnderlyingArguments(
|
||||
{args.problem_size.n(), args.problem_size.m(), args.problem_size.k()},
|
||||
|
||||
@@ -696,7 +696,7 @@ public:
|
||||
/// Constructs the GEMM.
|
||||
Gemm() { }
|
||||
|
||||
/// Helper to construct a transposed equivalent for the underying GEMM operator
|
||||
/// Helper to construct a transposed equivalent for the underlying GEMM operator
|
||||
static UnderlyingArguments to_underlying_arguments(Arguments const &args) {
|
||||
return UnderlyingArguments(
|
||||
{args.problem_size.n(), args.problem_size.m(), args.problem_size.k()},
|
||||
|
||||
@@ -653,7 +653,7 @@ public:
|
||||
/// Constructs the GEMM.
|
||||
GemmArray() { }
|
||||
|
||||
/// Helper to construct a transposed equivalent for the underying GEMM operator
|
||||
/// Helper to construct a transposed equivalent for the underlying GEMM operator
|
||||
static UnderlyingArguments to_underlying_arguments(Arguments const &args) {
|
||||
|
||||
GemmCoord problem_size{
|
||||
|
||||
@@ -626,7 +626,7 @@ public:
|
||||
/// Constructs the GEMM.
|
||||
GemmBatched() { }
|
||||
|
||||
/// Helper to construct a transposed equivalent for the underying GEMM operator
|
||||
/// Helper to construct a transposed equivalent for the underlying GEMM operator
|
||||
static UnderlyingArguments to_underlying_arguments(Arguments const &args) {
|
||||
return UnderlyingArguments(
|
||||
{args.problem_size.n(), args.problem_size.m(), args.problem_size.k()},
|
||||
|
||||
@@ -645,7 +645,7 @@ public:
|
||||
/// Constructs the GEMM.
|
||||
GemmComplex() { }
|
||||
|
||||
/// Helper to construct a transposed equivalent for the underying GEMM operator
|
||||
/// Helper to construct a transposed equivalent for the underlying GEMM operator
|
||||
static UnderlyingArguments to_underlying_arguments(Arguments const &args) {
|
||||
return UnderlyingArguments(
|
||||
{args.problem_size.n(), args.problem_size.m(), args.problem_size.k()},
|
||||
|
||||
@@ -561,7 +561,7 @@ public:
|
||||
/// Constructs the GEMM.
|
||||
GemmSplitKParallel() { }
|
||||
|
||||
/// Helper to construct a transposed equivalent for the underying GEMM operator
|
||||
/// Helper to construct a transposed equivalent for the underlying GEMM operator
|
||||
static UnderlyingArguments to_underlying_arguments(Arguments const &args) {
|
||||
return UnderlyingArguments(
|
||||
{args.problem_size.n(), args.problem_size.m(), args.problem_size.k()},
|
||||
|
||||
@@ -367,7 +367,7 @@ public:
|
||||
/// Constructs the GEMM.
|
||||
GemmUniversal() { }
|
||||
|
||||
/// Helper to construct a transposed equivalent for the underying GEMM operator
|
||||
/// Helper to construct a transposed equivalent for the underlying GEMM operator
|
||||
static Arguments to_underlying_arguments(Arguments const &args) {
|
||||
return args.transposed_problem();
|
||||
}
|
||||
|
||||
@@ -693,7 +693,7 @@ public:
|
||||
/// Constructs the GEMM.
|
||||
GemmUniversalAdapter() { }
|
||||
|
||||
/// Helper to construct a transposed equivalent for the underying GEMM operator
|
||||
/// Helper to construct a transposed equivalent for the underlying GEMM operator
|
||||
static Arguments to_underlying_arguments(Arguments const &args) {
|
||||
if (kInternalTranspose) {
|
||||
return args.transposed_problem();
|
||||
|
||||
@@ -311,7 +311,7 @@ public:
|
||||
/// Constructs the GEMM.
|
||||
GemmUniversalStreamkWithBroadcast() { }
|
||||
|
||||
/// Helper to construct a transposed equivalent for the underying GEMM operator
|
||||
/// Helper to construct a transposed equivalent for the underlying GEMM operator
|
||||
static Arguments to_underlying_arguments(Arguments const &args) {
|
||||
return args.transposed_problem();
|
||||
}
|
||||
|
||||
@@ -329,7 +329,7 @@ public:
|
||||
/// Constructs the GEMM.
|
||||
GemmUniversalWithAbsMax() { }
|
||||
|
||||
/// Helper to construct a transposed equivalent for the underying GEMM operator
|
||||
/// Helper to construct a transposed equivalent for the underlying GEMM operator
|
||||
static Arguments to_underlying_arguments(Arguments const &args) {
|
||||
return args.transposed_problem();
|
||||
}
|
||||
|
||||
@@ -311,7 +311,7 @@ public:
|
||||
/// Constructs the GEMM.
|
||||
GemmUniversalWithBroadcast() { }
|
||||
|
||||
/// Helper to construct a transposed equivalent for the underying GEMM operator
|
||||
/// Helper to construct a transposed equivalent for the underlying GEMM operator
|
||||
static Arguments to_underlying_arguments(Arguments const &args) {
|
||||
return args.transposed_problem();
|
||||
}
|
||||
|
||||
@@ -340,7 +340,7 @@ public:
|
||||
/// Constructs the GEMM.
|
||||
GemmWithKReduction() = default;
|
||||
|
||||
/// Helper to construct a transposed equivalent for the underying GEMM operator
|
||||
/// Helper to construct a transposed equivalent for the underlying GEMM operator
|
||||
static Arguments to_underlying_arguments(Arguments const &args) {
|
||||
return args.transposed_problem();
|
||||
}
|
||||
|
||||
@@ -473,7 +473,7 @@ public:
|
||||
/// Constructs the Rank2K.
|
||||
Rank2K() { }
|
||||
|
||||
/// Helper to construct a transposed equivalent for the underying Rank2K operator
|
||||
/// Helper to construct a transposed equivalent for the underlying Rank2K operator
|
||||
static Arguments to_underlying_arguments(Arguments const &args) {
|
||||
return args.transposed_problem();
|
||||
}
|
||||
|
||||
@@ -436,7 +436,7 @@ public:
|
||||
/// Constructs the RankK.
|
||||
RankK() { }
|
||||
|
||||
/// Helper to construct a transposed equivalent for the underying RankK operator
|
||||
/// Helper to construct a transposed equivalent for the underlying RankK operator
|
||||
static Arguments to_underlying_arguments(Arguments const &args) {
|
||||
return args;
|
||||
}
|
||||
|
||||
@@ -528,7 +528,7 @@ public:
|
||||
/// Constructs the Symm.
|
||||
Symm() { }
|
||||
|
||||
/// Helper to construct a transposed equivalent for the underying SYMM operator
|
||||
/// Helper to construct a transposed equivalent for the underlying SYMM operator
|
||||
static Arguments to_underlying_arguments(Arguments const &args) {
|
||||
return args.transposed_problem_size();
|
||||
}
|
||||
|
||||
@@ -300,7 +300,7 @@ class Trmm {
|
||||
static int const kAlignmentBKernel = (SideModeA == SideMode::kRight) ? AlignmentA : AlignmentB;
|
||||
static int const kAlignmentC = EpilogueOutputOp::kCount;
|
||||
static bool const kSplitKSerial = SplitKSerial;
|
||||
// Complex Transform don't appply to B
|
||||
// Complex Transform don't apply to B
|
||||
static ComplexTransform const kTransformA = TransformA;
|
||||
static ComplexTransform const kTransformB = ComplexTransform::kNone;
|
||||
static ComplexTransform const kTransformAKernel = (SideModeA == SideMode::kRight) ?
|
||||
@@ -651,7 +651,7 @@ class Trmm<ElementA_, LayoutA_, SideModeA, FillModeA, DiagTypeA,
|
||||
static int const kAlignmentA = AlignmentA;
|
||||
static int const kAlignmentB = AlignmentB;
|
||||
static ComplexTransform const kTransformA = TransformA;
|
||||
// Complex Transform don't appply to B
|
||||
// Complex Transform don't apply to B
|
||||
static ComplexTransform const kTransformB = ComplexTransform::kNone;
|
||||
static bool const kSplitKSerial = SplitKSerial;
|
||||
|
||||
@@ -694,7 +694,7 @@ public:
|
||||
/// Constructs the TRMM.
|
||||
Trmm() { }
|
||||
|
||||
/// Helper to construct a transposed equivalent for the underying TRMM operator which is identical
|
||||
/// Helper to construct a transposed equivalent for the underlying TRMM operator which is identical
|
||||
static Arguments to_underlying_arguments(Arguments const &args) {
|
||||
return args.transposed_problem_size();
|
||||
}
|
||||
|
||||
@@ -160,7 +160,7 @@ struct DefaultSymmComplex<
|
||||
Operator, SplitKSerial, BlasMode::kSymmetric> {
|
||||
|
||||
static BlasMode const kBlasMode = BlasMode::kSymmetric;
|
||||
// Complex Transform don't appply to A or B for SYMM
|
||||
// Complex Transform don't apply to A or B for SYMM
|
||||
static ComplexTransform const TransformA = ComplexTransform::kNone;
|
||||
static ComplexTransform const TransformB = ComplexTransform::kNone;
|
||||
|
||||
@@ -353,7 +353,7 @@ struct DefaultSymmComplex<
|
||||
Operator, SplitKSerial, BlasMode::kSymmetric> {
|
||||
|
||||
static BlasMode const kBlasMode = BlasMode::kSymmetric;
|
||||
// Complex Transform don't appply to A or B for SYMM
|
||||
// Complex Transform don't apply to A or B for SYMM
|
||||
static ComplexTransform const TransformA = ComplexTransform::kNone;
|
||||
static ComplexTransform const TransformB = ComplexTransform::kNone;
|
||||
|
||||
|
||||
@@ -70,7 +70,7 @@ namespace detail
|
||||
using CDType = typename FragmentCD::value_type;
|
||||
|
||||
static_assert(FragmentCD::kElements == FragmentAccumulator::kElements,
|
||||
"Mistmatch in fragment sizes.");
|
||||
"Mismatch in fragment sizes.");
|
||||
|
||||
for (int i = 0; i < FragmentCD::kElements; ++i)
|
||||
{
|
||||
|
||||
@@ -52,7 +52,7 @@ namespace cutlass::gemm::kernel::detail {
|
||||
// Therefore, we don't how many tiles there will be for the scheduler to hand out.
|
||||
// Hence, we have a SM90 style static group scheduler that launches the largest grid possible.
|
||||
// If we had access to host-side problem shapes, one could to use it to figure out the grid shape
|
||||
// and thereafter use CLC query (which can then be linearized and mapped to an approriate tile coord).
|
||||
// and thereafter use CLC query (which can then be linearized and mapped to an appropriate tile coord).
|
||||
|
||||
template<class GroupProblemShape, int SchedulerPipelineStageCount>
|
||||
class PersistentTileSchedulerSm100Group {
|
||||
|
||||
@@ -728,7 +728,7 @@ private:
|
||||
auto cluster_start_linear_id = sm_count * wave_idx + cluster_idx;
|
||||
|
||||
// Determine the offset of this CTA in the preferred cluster shape.
|
||||
// This calculation aims to accomodate both cases in which this CTA is part of a preferred cluster
|
||||
// This calculation aims to accommodate both cases in which this CTA is part of a preferred cluster
|
||||
// and those in which it is part of a fallback cluster.
|
||||
//
|
||||
// The calculation is performed by computing the starting M and N index of the preferred cluster that
|
||||
|
||||
@@ -120,7 +120,7 @@ public:
|
||||
// Tensor A/B could have different buffering, with number of KBLOCK, aka TILEK,
|
||||
// and STAGEs. It let AsymmetricKRatio, equals KBLOCK_A / KBLOCK_B, to control
|
||||
// the balance of A/B loading, make sure A/B's pipeline keep same cadence
|
||||
// when procude / consume data.
|
||||
// when produce / consume data.
|
||||
// Currently, AsymmetricKRatio = {1, 2} is the only support.
|
||||
static constexpr bool isAsymmetric = DispatchPolicy::Schedule::isAsymmetric;
|
||||
static constexpr uint32_t AsymmetricKRatio = isAsymmetric ? 2 : 1;
|
||||
|
||||
@@ -409,7 +409,7 @@ struct PersistentTileSchedulerSm90StreamKParams {
|
||||
FastDivmodU64 divmod_clusters_mnl_{};
|
||||
|
||||
// We divide up the number of stream-K tiles amongst G groups of stream-K units.
|
||||
// The stream-K units within a group collaborate to comptue over the `sk_tiles / G`
|
||||
// The stream-K units within a group collaborate to compute over the `sk_tiles / G`
|
||||
// tiles assigned to that group. Non-unit group sizes can help to preserve L2 locality of
|
||||
// partial chunks computed by stream-K units -- units 0 in each group will compute identical K extents
|
||||
// of tiles that would be assigned in the same wave according to the rasterization order of the
|
||||
@@ -932,7 +932,7 @@ struct PersistentTileSchedulerSm90StreamKParams {
|
||||
}
|
||||
}
|
||||
|
||||
// Given decomposition mode output from heuristic, set all feilds of params.
|
||||
// Given decomposition mode output from heuristic, set all fields of params.
|
||||
void set_params(
|
||||
DecompositionMode heuristic_mode,
|
||||
uint32_t groups,
|
||||
@@ -954,7 +954,7 @@ struct PersistentTileSchedulerSm90StreamKParams {
|
||||
, uint32_t ktile_start_alignment_count
|
||||
) {
|
||||
// The highest priority when customers set as splitk mode, may set
|
||||
// with a adpated splits value rather than the original splits
|
||||
// with a adapted splits value rather than the original splits
|
||||
// even it does not make sense
|
||||
if (splits > 1 && heuristic_mode == DecompositionMode::SplitK) {
|
||||
set_params_basic(
|
||||
|
||||
@@ -94,7 +94,7 @@ template <
|
||||
typename InstructionShape_,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Operation perfomed by GEMM
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Store the accumulators in row major or column major. Row major is used
|
||||
/// when output layout is interleaved.
|
||||
@@ -364,7 +364,7 @@ template <
|
||||
typename InstructionShape,
|
||||
/// Number of stages used in the multistage mainloop
|
||||
int Stages,
|
||||
/// Operation perfomed by GEMM
|
||||
/// Operation performed by GEMM
|
||||
typename Operator
|
||||
>
|
||||
struct DefaultEllMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
@@ -429,7 +429,7 @@ template <
|
||||
typename InstructionShape,
|
||||
/// Number of stages used in the multistage mainloop
|
||||
int Stages,
|
||||
/// Operation perfomed by GEMM
|
||||
/// Operation performed by GEMM
|
||||
typename Operator
|
||||
>
|
||||
struct DefaultEllMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
|
||||
@@ -91,7 +91,7 @@ template <
|
||||
typename InstructionShape_,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Operation perfomed by GEMM
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Store the accumulators in row major or column major. Row major is used
|
||||
/// when output layout is interleaved.
|
||||
@@ -417,7 +417,7 @@ template <
|
||||
typename InstructionShape,
|
||||
/// Number of stages used in the multistage mainloop
|
||||
int Stages,
|
||||
/// Operation perfomed by GEMM
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Gather operand A by using an index array
|
||||
bool GatherA,
|
||||
@@ -498,7 +498,7 @@ template <
|
||||
typename InstructionShape,
|
||||
/// Number of stages used in the multistage mainloop
|
||||
int Stages,
|
||||
/// Operation perfomed by GEMM
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Use zfill or predicate for out-of-bound cp.async
|
||||
SharedMemoryClearOption SharedMemoryClear,
|
||||
|
||||
@@ -38,7 +38,7 @@
|
||||
instructions.
|
||||
|
||||
SM80 Multi stage kernel expects stage number to be larger or equal to 3
|
||||
to use asyncronous copy.
|
||||
to use asynchronous copy.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
@@ -88,7 +88,7 @@ template <
|
||||
typename InstructionShape,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Operation perfomed by GEMM
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Store the accumulators in row major or column major. Row major is used
|
||||
/// when output layout is interleaved.
|
||||
|
||||
@@ -91,7 +91,7 @@ template <
|
||||
/// Whether problem has been transformed. This determines to which operand
|
||||
/// the softmax is applied.
|
||||
bool InternalTranspose,
|
||||
/// Operation perfomed by GEMM
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Store the accumulators in row major or column major. Row major is used
|
||||
/// when output layout is interleaved.
|
||||
|
||||
@@ -82,7 +82,7 @@ template <
|
||||
typename InstructionShape,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Operation perfomed by GEMM
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Store the accumulators in row major or column major. Row major is used
|
||||
/// when output layout is interleaved.
|
||||
|
||||
@@ -87,7 +87,7 @@ template <
|
||||
typename InstructionShape_,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Operation perfomed by GEMM
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Store the accumulators in row major or column major. Row major is used
|
||||
/// when output layout is interleaved.
|
||||
@@ -123,7 +123,7 @@ template <
|
||||
typename InstructionShape,
|
||||
/// Number of stages used in the multistage mainloop
|
||||
int Stages,
|
||||
/// Operation perfomed by GEMM
|
||||
/// Operation performed by GEMM
|
||||
typename Operator
|
||||
>
|
||||
struct DefaultSparseMma<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB,
|
||||
|
||||
@@ -96,7 +96,7 @@ template <
|
||||
typename InstructionShape_,
|
||||
/// Number of stages used in the pipelined mainloop
|
||||
int Stages,
|
||||
/// Operation perfomed by GEMM
|
||||
/// Operation performed by GEMM
|
||||
typename Operator,
|
||||
/// Store the accumulators in row major or column major. Row major is used
|
||||
/// when output layout is interleaved.
|
||||
@@ -138,7 +138,7 @@ template <
|
||||
typename InstructionShape,
|
||||
/// Number of stages used in the multistage mainloop
|
||||
int Stages,
|
||||
/// Operation perfomed by GEMM
|
||||
/// Operation performed by GEMM
|
||||
typename Operator
|
||||
>
|
||||
struct DefaultTrmm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
@@ -221,7 +221,7 @@ template <
|
||||
typename InstructionShape,
|
||||
/// Number of stages used in the multistage mainloop
|
||||
int Stages,
|
||||
/// Operation perfomed by GEMM
|
||||
/// Operation performed by GEMM
|
||||
typename Operator
|
||||
>
|
||||
struct DefaultTrmm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
@@ -304,7 +304,7 @@ template <
|
||||
typename InstructionShape,
|
||||
/// Number of stages used in the multistage mainloop
|
||||
int Stages,
|
||||
/// Operation perfomed by GEMM
|
||||
/// Operation performed by GEMM
|
||||
typename Operator
|
||||
>
|
||||
struct DefaultTrmm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
@@ -385,7 +385,7 @@ template <
|
||||
typename InstructionShape,
|
||||
/// Number of stages used in the multistage mainloop
|
||||
int Stages,
|
||||
/// Operation perfomed by GEMM
|
||||
/// Operation performed by GEMM
|
||||
typename Operator
|
||||
>
|
||||
struct DefaultTrmm<ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB,
|
||||
|
||||
@@ -92,7 +92,7 @@ public:
|
||||
FragmentC &accum, ///< destination accumulator tile
|
||||
IteratorA iterator_A, ///< iterator over A operand in global memory
|
||||
IteratorB iterator_B, ///< iterator over B operand in global memory
|
||||
FragmentC const &src_accum) { ///< source accumualtor tile
|
||||
FragmentC const &src_accum) { ///< source accumulator tile
|
||||
|
||||
//
|
||||
// Prologue
|
||||
|
||||
@@ -115,7 +115,7 @@ public:
|
||||
///< Policy describing tuning details
|
||||
using Policy = Policy_;
|
||||
|
||||
///< Archtecture tag
|
||||
///< Architecture tag
|
||||
using ArchTag = arch::Sm80;
|
||||
|
||||
using SmemIteratorA = SmemIteratorA_;
|
||||
|
||||
@@ -173,7 +173,7 @@ public:
|
||||
FragmentC &accum, ///< destination accumulator tile
|
||||
IteratorA iterator_A, ///< iterator over A operand in global memory
|
||||
IteratorB iterator_B, ///< iterator over B operand in global memory
|
||||
FragmentC const &src_accum) { ///< source accumualtor tile
|
||||
FragmentC const &src_accum) { ///< source accumulator tile
|
||||
|
||||
//
|
||||
// Prologue
|
||||
|
||||
@@ -513,7 +513,7 @@ struct ThreadblockSwizzleStreamK {
|
||||
// - More than three peers working on an SK tile. (This occurs when the ratio of
|
||||
// SK-blocks to SK-tiles > 2, as a single tile may be covered by four SK-blocks,
|
||||
// e.g.:[partial-block | block | block | partial-block] ). With three or
|
||||
// less peers, the two non-finishing SK-blocks are not expexted to contend.
|
||||
// less peers, the two non-finishing SK-blocks are not expected to contend.
|
||||
if ((kReductionStrategy == kMixed) &&
|
||||
(sk_waves < sm_occupancy) &&
|
||||
(sk_blocks > 2 * sk_tiles))
|
||||
|
||||
@@ -782,7 +782,7 @@ public:
|
||||
for (int n = 0; n < MmaIterations::kColumn; ++n) {
|
||||
|
||||
// negate OperandB to accumulate -(a.imag()*b.imag())
|
||||
// negating OperandB emits less instrucitons than negating OperandA as OperandB has less elements
|
||||
// negating OperandB emits less instructions than negating OperandA as OperandB has less elements
|
||||
negate<InstMmaOperandB> negate_op;
|
||||
|
||||
// Real-valued accumulator part
|
||||
|
||||
@@ -598,7 +598,7 @@ public:
|
||||
for (int n = 0; n < MmaIterations::kColumn; ++n) {
|
||||
|
||||
// negate OperandB to accumulate -(a.imag()*b.imag())
|
||||
// negating OperandB emits less instrucitons than negating OperandA as OperandB has less elements
|
||||
// negating OperandB emits less instructions than negating OperandA as OperandB has less elements
|
||||
negate<InstMmaOperandB> negate_op;
|
||||
|
||||
// Real-valued accumulator part
|
||||
|
||||
@@ -427,7 +427,7 @@ public:
|
||||
using TransformedFragmentA =
|
||||
Array<ElementAMma, FragmentA::kElements>;
|
||||
|
||||
/// Underlying arch::Mma instruction operand fragement for matrix A
|
||||
/// Underlying arch::Mma instruction operand fragment for matrix A
|
||||
using MmaOperandA = typename ArchMmaOperator::FragmentA;
|
||||
|
||||
/// Iterates over the B operand in Shared Memory
|
||||
@@ -443,7 +443,7 @@ public:
|
||||
using TransformedFragmentB =
|
||||
Array<ElementBMma, FragmentB::kElements>;
|
||||
|
||||
/// Underlying arch::Mma instruction operand fragement for matrix B
|
||||
/// Underlying arch::Mma instruction operand fragment for matrix B
|
||||
using MmaOperandB = typename ArchMmaOperator::FragmentB;
|
||||
|
||||
/// Iterates over the C operand in memory
|
||||
@@ -454,7 +454,7 @@ public:
|
||||
/// Storage for C tile
|
||||
using FragmentC = typename IteratorC::Fragment;
|
||||
|
||||
/// Underlying arch::Mma instruction operand fragement for matrix C
|
||||
/// Underlying arch::Mma instruction operand fragment for matrix C
|
||||
using MmaOperandC = typename ArchMmaOperator::FragmentC;
|
||||
|
||||
/// Number of mma operations performed
|
||||
|
||||
@@ -117,7 +117,7 @@ public:
|
||||
/// Shape of the warp in units of thread (concept: MmaLanePolicySimt)
|
||||
using Policy = Policy_;
|
||||
|
||||
/// Equivalant base dense mma
|
||||
/// Equivalent base dense mma
|
||||
using Base = MmaTensorOp<Shape, ElementA, LayoutA, ElementB, LayoutB,
|
||||
ElementC, LayoutC, Policy, PartitionsK_,
|
||||
AccumulatorsInRowMajor, Enable>;
|
||||
|
||||
@@ -33,7 +33,7 @@
|
||||
\brief This defines a "fragment" iterator for visiting the fragments of a warp tile
|
||||
that participate in one warp-level mma operation.
|
||||
|
||||
Typically, this is used to access the accumulator tile/fragement of a warp-level mma operation.
|
||||
Typically, this is used to access the accumulator tile/fragment of a warp-level mma operation.
|
||||
The accumulator tile is then partitioned into smaller tiles/fragments that can be fed into
|
||||
next warp-level mma operation.
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ namespace warp {
|
||||
|
||||
|
||||
/// Tile access iterator
|
||||
/// Each iteration acess in the tile is
|
||||
/// Each iteration access in the tile is
|
||||
/// used as multiplicand for one
|
||||
/// warp-level matrix multiplication
|
||||
template <
|
||||
|
||||
@@ -68,7 +68,7 @@
|
||||
|
||||
///////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
// Optionally target F16C extentions to accelerate half-precision conversion.
|
||||
// Optionally target F16C extensions to accelerate half-precision conversion.
|
||||
#if !defined(__CUDA_ARCH__) && (CUTLASS_ENABLE_F16C)
|
||||
#if defined(_MSC_VER)
|
||||
|
||||
|
||||
@@ -90,7 +90,7 @@ struct KernelHardwareInfo {
|
||||
int max_active_clusters = 0;
|
||||
#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED)
|
||||
ClusterLauncher::LaunchConfig cluster_launch_config = ClusterLauncher::make_cluster_launch_config(
|
||||
cluster_dims /* minumum grid dim */, cluster_dims, {threads_per_block, 1, 1});
|
||||
cluster_dims /* minimum grid dim */, cluster_dims, {threads_per_block, 1, 1});
|
||||
// Given the kernel function and launch configuration, return the maximum number of clusters that could co-exist on the target device.
|
||||
cudaError_t result = cudaOccupancyMaxActiveClusters(&max_active_clusters, kernel_ptr, &cluster_launch_config.launch_config);
|
||||
if (result != cudaSuccess) {
|
||||
|
||||
@@ -101,7 +101,7 @@ struct Matrix<Element_, 1, 2> {
|
||||
data = rhs.data;
|
||||
}
|
||||
|
||||
/// Constucts a 1-by-2 matrix from scalar elements
|
||||
/// Constructs a 1-by-2 matrix from scalar elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix(
|
||||
Element _0_0, Element _0_1
|
||||
@@ -599,7 +599,7 @@ template <typename Element>
|
||||
using Matrix1x2 = Matrix<Element, 1, 2>;
|
||||
|
||||
|
||||
/// Free funciton to infer element type from template arguments
|
||||
/// Free function to infer element type from template arguments
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE Matrix1x2<Element> make_Matrix1x2(
|
||||
Element _0_0, Element _0_1
|
||||
@@ -658,7 +658,7 @@ struct Matrix<Element_, 1, 3> {
|
||||
data = rhs.data;
|
||||
}
|
||||
|
||||
/// Constucts a 1-by-3 matrix from scalar elements
|
||||
/// Constructs a 1-by-3 matrix from scalar elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix(
|
||||
Element _0_0, Element _0_1, Element _0_2
|
||||
@@ -1226,7 +1226,7 @@ template <typename Element>
|
||||
using Matrix1x3 = Matrix<Element, 1, 3>;
|
||||
|
||||
|
||||
/// Free funciton to infer element type from template arguments
|
||||
/// Free function to infer element type from template arguments
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE Matrix1x3<Element> make_Matrix1x3(
|
||||
Element _0_0, Element _0_1, Element _0_2
|
||||
@@ -1285,7 +1285,7 @@ struct Matrix<Element_, 1, 4> {
|
||||
data = rhs.data;
|
||||
}
|
||||
|
||||
/// Constucts a 1-by-4 matrix from scalar elements
|
||||
/// Constructs a 1-by-4 matrix from scalar elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix(
|
||||
Element _0_0, Element _0_1, Element _0_2, Element _0_3
|
||||
@@ -1905,7 +1905,7 @@ template <typename Element>
|
||||
using Matrix1x4 = Matrix<Element, 1, 4>;
|
||||
|
||||
|
||||
/// Free funciton to infer element type from template arguments
|
||||
/// Free function to infer element type from template arguments
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE Matrix1x4<Element> make_Matrix1x4(
|
||||
Element _0_0, Element _0_1, Element _0_2, Element _0_3
|
||||
@@ -1964,7 +1964,7 @@ struct Matrix<Element_, 2, 1> {
|
||||
data = rhs.data;
|
||||
}
|
||||
|
||||
/// Constucts a 2-by-1 matrix from scalar elements
|
||||
/// Constructs a 2-by-1 matrix from scalar elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix(
|
||||
Element _0_0,
|
||||
@@ -2471,7 +2471,7 @@ template <typename Element>
|
||||
using Matrix2x1 = Matrix<Element, 2, 1>;
|
||||
|
||||
|
||||
/// Free funciton to infer element type from template arguments
|
||||
/// Free function to infer element type from template arguments
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE Matrix2x1<Element> make_Matrix2x1(
|
||||
Element _0_0,
|
||||
@@ -2532,7 +2532,7 @@ struct Matrix<Element_, 2, 2> {
|
||||
data = rhs.data;
|
||||
}
|
||||
|
||||
/// Constucts a 2-by-2 matrix from scalar elements
|
||||
/// Constructs a 2-by-2 matrix from scalar elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix(
|
||||
Element _0_0, Element _0_1,
|
||||
@@ -2543,7 +2543,7 @@ struct Matrix<Element_, 2, 2> {
|
||||
data[2] = _1_0; data[3] = _1_1;
|
||||
}
|
||||
|
||||
/// Constucts a 2-by-2 matrix from row vectors
|
||||
/// Constructs a 2-by-2 matrix from row vectors
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix(
|
||||
Matrix<Element, 1, 2> const &row_0,
|
||||
@@ -3258,7 +3258,7 @@ template <typename Element>
|
||||
using Matrix2x2 = Matrix<Element, 2, 2>;
|
||||
|
||||
|
||||
/// Free funciton to infer element type from template arguments
|
||||
/// Free function to infer element type from template arguments
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE Matrix2x2<Element> make_Matrix2x2(
|
||||
Element _0_0, Element _0_1,
|
||||
@@ -3319,7 +3319,7 @@ struct Matrix<Element_, 2, 3> {
|
||||
data = rhs.data;
|
||||
}
|
||||
|
||||
/// Constucts a 2-by-3 matrix from scalar elements
|
||||
/// Constructs a 2-by-3 matrix from scalar elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix(
|
||||
Element _0_0, Element _0_1, Element _0_2,
|
||||
@@ -3330,7 +3330,7 @@ struct Matrix<Element_, 2, 3> {
|
||||
data[3] = _1_0; data[4] = _1_1; data[5] = _1_2;
|
||||
}
|
||||
|
||||
/// Constucts a 2-by-3 matrix from row vectors
|
||||
/// Constructs a 2-by-3 matrix from row vectors
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix(
|
||||
Matrix<Element, 1, 3> const &row_0,
|
||||
@@ -4128,7 +4128,7 @@ template <typename Element>
|
||||
using Matrix2x3 = Matrix<Element, 2, 3>;
|
||||
|
||||
|
||||
/// Free funciton to infer element type from template arguments
|
||||
/// Free function to infer element type from template arguments
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE Matrix2x3<Element> make_Matrix2x3(
|
||||
Element _0_0, Element _0_1, Element _0_2,
|
||||
@@ -4189,7 +4189,7 @@ struct Matrix<Element_, 2, 4> {
|
||||
data = rhs.data;
|
||||
}
|
||||
|
||||
/// Constucts a 2-by-4 matrix from scalar elements
|
||||
/// Constructs a 2-by-4 matrix from scalar elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix(
|
||||
Element _0_0, Element _0_1, Element _0_2, Element _0_3,
|
||||
@@ -4200,7 +4200,7 @@ struct Matrix<Element_, 2, 4> {
|
||||
data[4] = _1_0; data[5] = _1_1; data[6] = _1_2; data[7] = _1_3;
|
||||
}
|
||||
|
||||
/// Constucts a 2-by-4 matrix from row vectors
|
||||
/// Constructs a 2-by-4 matrix from row vectors
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix(
|
||||
Matrix<Element, 1, 4> const &row_0,
|
||||
@@ -5134,7 +5134,7 @@ template <typename Element>
|
||||
using Matrix2x4 = Matrix<Element, 2, 4>;
|
||||
|
||||
|
||||
/// Free funciton to infer element type from template arguments
|
||||
/// Free function to infer element type from template arguments
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE Matrix2x4<Element> make_Matrix2x4(
|
||||
Element _0_0, Element _0_1, Element _0_2, Element _0_3,
|
||||
@@ -5195,7 +5195,7 @@ struct Matrix<Element_, 3, 1> {
|
||||
data = rhs.data;
|
||||
}
|
||||
|
||||
/// Constucts a 3-by-1 matrix from scalar elements
|
||||
/// Constructs a 3-by-1 matrix from scalar elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix(
|
||||
Element _0_0,
|
||||
@@ -5780,7 +5780,7 @@ template <typename Element>
|
||||
using Matrix3x1 = Matrix<Element, 3, 1>;
|
||||
|
||||
|
||||
/// Free funciton to infer element type from template arguments
|
||||
/// Free function to infer element type from template arguments
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE Matrix3x1<Element> make_Matrix3x1(
|
||||
Element _0_0,
|
||||
@@ -5843,7 +5843,7 @@ struct Matrix<Element_, 3, 2> {
|
||||
data = rhs.data;
|
||||
}
|
||||
|
||||
/// Constucts a 3-by-2 matrix from scalar elements
|
||||
/// Constructs a 3-by-2 matrix from scalar elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix(
|
||||
Element _0_0, Element _0_1,
|
||||
@@ -5856,7 +5856,7 @@ struct Matrix<Element_, 3, 2> {
|
||||
data[4] = _2_0; data[5] = _2_1;
|
||||
}
|
||||
|
||||
/// Constucts a 3-by-2 matrix from row vectors
|
||||
/// Constructs a 3-by-2 matrix from row vectors
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix(
|
||||
Matrix<Element, 1, 2> const &row_0,
|
||||
@@ -6665,7 +6665,7 @@ template <typename Element>
|
||||
using Matrix3x2 = Matrix<Element, 3, 2>;
|
||||
|
||||
|
||||
/// Free funciton to infer element type from template arguments
|
||||
/// Free function to infer element type from template arguments
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE Matrix3x2<Element> make_Matrix3x2(
|
||||
Element _0_0, Element _0_1,
|
||||
@@ -6728,7 +6728,7 @@ struct Matrix<Element_, 3, 3> {
|
||||
data = rhs.data;
|
||||
}
|
||||
|
||||
/// Constucts a 3-by-3 matrix from scalar elements
|
||||
/// Constructs a 3-by-3 matrix from scalar elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix(
|
||||
Element _0_0, Element _0_1, Element _0_2,
|
||||
@@ -6741,7 +6741,7 @@ struct Matrix<Element_, 3, 3> {
|
||||
data[6] = _2_0; data[7] = _2_1; data[8] = _2_2;
|
||||
}
|
||||
|
||||
/// Constucts a 3-by-3 matrix from row vectors
|
||||
/// Constructs a 3-by-3 matrix from row vectors
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix(
|
||||
Matrix<Element, 1, 3> const &row_0,
|
||||
@@ -7896,7 +7896,7 @@ template <typename Element>
|
||||
using Matrix3x3 = Matrix<Element, 3, 3>;
|
||||
|
||||
|
||||
/// Free funciton to infer element type from template arguments
|
||||
/// Free function to infer element type from template arguments
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE Matrix3x3<Element> make_Matrix3x3(
|
||||
Element _0_0, Element _0_1, Element _0_2,
|
||||
@@ -7959,7 +7959,7 @@ struct Matrix<Element_, 3, 4> {
|
||||
data = rhs.data;
|
||||
}
|
||||
|
||||
/// Constucts a 3-by-4 matrix from scalar elements
|
||||
/// Constructs a 3-by-4 matrix from scalar elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix(
|
||||
Element _0_0, Element _0_1, Element _0_2, Element _0_3,
|
||||
@@ -7972,7 +7972,7 @@ struct Matrix<Element_, 3, 4> {
|
||||
data[8] = _2_0; data[9] = _2_1; data[10] = _2_2; data[11] = _2_3;
|
||||
}
|
||||
|
||||
/// Constucts a 3-by-4 matrix from row vectors
|
||||
/// Constructs a 3-by-4 matrix from row vectors
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix(
|
||||
Matrix<Element, 1, 4> const &row_0,
|
||||
@@ -9208,7 +9208,7 @@ template <typename Element>
|
||||
using Matrix3x4 = Matrix<Element, 3, 4>;
|
||||
|
||||
|
||||
/// Free funciton to infer element type from template arguments
|
||||
/// Free function to infer element type from template arguments
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE Matrix3x4<Element> make_Matrix3x4(
|
||||
Element _0_0, Element _0_1, Element _0_2, Element _0_3,
|
||||
@@ -9271,7 +9271,7 @@ struct Matrix<Element_, 4, 1> {
|
||||
data = rhs.data;
|
||||
}
|
||||
|
||||
/// Constucts a 4-by-1 matrix from scalar elements
|
||||
/// Constructs a 4-by-1 matrix from scalar elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix(
|
||||
Element _0_0,
|
||||
@@ -9918,7 +9918,7 @@ template <typename Element>
|
||||
using Matrix4x1 = Matrix<Element, 4, 1>;
|
||||
|
||||
|
||||
/// Free funciton to infer element type from template arguments
|
||||
/// Free function to infer element type from template arguments
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE Matrix4x1<Element> make_Matrix4x1(
|
||||
Element _0_0,
|
||||
@@ -9983,7 +9983,7 @@ struct Matrix<Element_, 4, 2> {
|
||||
data = rhs.data;
|
||||
}
|
||||
|
||||
/// Constucts a 4-by-2 matrix from scalar elements
|
||||
/// Constructs a 4-by-2 matrix from scalar elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix(
|
||||
Element _0_0, Element _0_1,
|
||||
@@ -9998,7 +9998,7 @@ struct Matrix<Element_, 4, 2> {
|
||||
data[6] = _3_0; data[7] = _3_1;
|
||||
}
|
||||
|
||||
/// Constucts a 4-by-2 matrix from row vectors
|
||||
/// Constructs a 4-by-2 matrix from row vectors
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix(
|
||||
Matrix<Element, 1, 2> const &row_0,
|
||||
@@ -10958,7 +10958,7 @@ template <typename Element>
|
||||
using Matrix4x2 = Matrix<Element, 4, 2>;
|
||||
|
||||
|
||||
/// Free funciton to infer element type from template arguments
|
||||
/// Free function to infer element type from template arguments
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE Matrix4x2<Element> make_Matrix4x2(
|
||||
Element _0_0, Element _0_1,
|
||||
@@ -11023,7 +11023,7 @@ struct Matrix<Element_, 4, 3> {
|
||||
data = rhs.data;
|
||||
}
|
||||
|
||||
/// Constucts a 4-by-3 matrix from scalar elements
|
||||
/// Constructs a 4-by-3 matrix from scalar elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix(
|
||||
Element _0_0, Element _0_1, Element _0_2,
|
||||
@@ -11038,7 +11038,7 @@ struct Matrix<Element_, 4, 3> {
|
||||
data[9] = _3_0; data[10] = _3_1; data[11] = _3_2;
|
||||
}
|
||||
|
||||
/// Constucts a 4-by-3 matrix from row vectors
|
||||
/// Constructs a 4-by-3 matrix from row vectors
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix(
|
||||
Matrix<Element, 1, 3> const &row_0,
|
||||
@@ -12291,7 +12291,7 @@ template <typename Element>
|
||||
using Matrix4x3 = Matrix<Element, 4, 3>;
|
||||
|
||||
|
||||
/// Free funciton to infer element type from template arguments
|
||||
/// Free function to infer element type from template arguments
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE Matrix4x3<Element> make_Matrix4x3(
|
||||
Element _0_0, Element _0_1, Element _0_2,
|
||||
@@ -12356,7 +12356,7 @@ struct Matrix<Element_, 4, 4> {
|
||||
data = rhs.data;
|
||||
}
|
||||
|
||||
/// Constucts a 4-by-4 matrix from scalar elements
|
||||
/// Constructs a 4-by-4 matrix from scalar elements
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix(
|
||||
Element _0_0, Element _0_1, Element _0_2, Element _0_3,
|
||||
@@ -12371,7 +12371,7 @@ struct Matrix<Element_, 4, 4> {
|
||||
data[12] = _3_0; data[13] = _3_1; data[14] = _3_2; data[15] = _3_3;
|
||||
}
|
||||
|
||||
/// Constucts a 4-by-4 matrix from row vectors
|
||||
/// Constructs a 4-by-4 matrix from row vectors
|
||||
CUTLASS_HOST_DEVICE
|
||||
Matrix(
|
||||
Matrix<Element, 1, 4> const &row_0,
|
||||
@@ -14096,7 +14096,7 @@ template <typename Element>
|
||||
using Matrix4x4 = Matrix<Element, 4, 4>;
|
||||
|
||||
|
||||
/// Free funciton to infer element type from template arguments
|
||||
/// Free function to infer element type from template arguments
|
||||
template <typename Element>
|
||||
CUTLASS_HOST_DEVICE Matrix4x4<Element> make_Matrix4x4(
|
||||
Element _0_0, Element _0_1, Element _0_2, Element _0_3,
|
||||
|
||||
@@ -51,7 +51,7 @@ namespace cutlass {
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
/// Floating-point rounding style similare to Standard Library's formats but supporting
|
||||
/// Floating-point rounding style similar to Standard Library's formats but supporting
|
||||
/// additional rounding options.
|
||||
enum class FloatRoundStyle {
|
||||
round_indeterminate, ///< rounding mode unknown
|
||||
@@ -6175,7 +6175,7 @@ private:
|
||||
asm volatile("prmt.b32 %0,%1,%1,%2;\n" : "=r"(r[ii]) : "r"(src_reg), "r"(prmt_indices[ii]));
|
||||
}
|
||||
|
||||
// In the absense of add.s16x2 instruction, use bit-wise operation to execute signed addition with magic numbers to achieve
|
||||
// In the absence of add.s16x2 instruction, use bit-wise operation to execute signed addition with magic numbers to achieve
|
||||
// the same result as add.s16x2 instruction.
|
||||
// (See https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#logic-and-shift-instructions-lop3)
|
||||
// For a logical operation F(a, b, c) the value of kImmLut can be computed by applying the same operation to
|
||||
|
||||
@@ -289,7 +289,7 @@ public:
|
||||
static constexpr int NumWarpTilePerWarpgroupTile = NumWarpsPerWarpGroup * (Steps == 8 ? 2 : 1);
|
||||
|
||||
static constexpr int WarpTileSize = WarpgroupTileSize / NumWarpTilePerWarpgroupTile;
|
||||
static_assert(WarpTileSize >= WarpThreadShapeN && WarpTileSize >= WarpThreadShapeK, "Invaild warp thread shape." );
|
||||
static_assert(WarpTileSize >= WarpThreadShapeN && WarpTileSize >= WarpThreadShapeK, "Invalid warp thread shape." );
|
||||
static constexpr int TilesPerWarp = 2; // Each Warp would process 2 warp_tiles in one step.
|
||||
static constexpr int64_t WarpTileNCoordLUT = 06723763275316420;
|
||||
static constexpr int64_t WarpTileKCoordLUT = 05410541064206420;
|
||||
@@ -510,7 +510,7 @@ public:
|
||||
static constexpr int NumWarpTilePerWarpgroupTile = NumWarpsPerWarpGroup * (Steps == 8 ? 2 : 1);
|
||||
|
||||
static constexpr int WarpTileSize = WarpgroupTileSize / NumWarpTilePerWarpgroupTile;
|
||||
static_assert(WarpTileSize >= WarpThreadShapeN && WarpTileSize >= WarpThreadShapeK, "Invaild warp thread shape." );
|
||||
static_assert(WarpTileSize >= WarpThreadShapeN && WarpTileSize >= WarpThreadShapeK, "Invalid warp thread shape." );
|
||||
static constexpr int TilesPerWarp = 2; // Each Warp would process 2 warp_tiles in one step.
|
||||
static constexpr int64_t WarpTileNCoordLUT = 06723763275316420;
|
||||
static constexpr int64_t WarpTileKCoordLUT = 05410541064206420;
|
||||
|
||||
@@ -298,7 +298,7 @@ struct PitchLinearWarpRakedThreadMap {
|
||||
static_assert(Iterations::kCount,
|
||||
"Number of iterations must be non-zero");
|
||||
|
||||
///< Delta betweeen accesses (units of elements, concept: PitchLinearShape)
|
||||
///< Delta between accesses (units of elements, concept: PitchLinearShape)
|
||||
using Delta = layout::PitchLinearShape<
|
||||
Detail::WarpThreadArrangement::kContiguous * kElementsPerAccess,
|
||||
Detail::WarpThreadArrangement::kStrided
|
||||
@@ -427,7 +427,7 @@ struct PitchLinearStridedWarpRakedThreadMap {
|
||||
static_assert(Iterations::kCount,
|
||||
"Number of iterations must be non-zero");
|
||||
|
||||
///< Delta betweeen accesses (units of elements, concept: PitchLinearShape)
|
||||
///< Delta between accesses (units of elements, concept: PitchLinearShape)
|
||||
using Delta = typename BaseThreadMap::Delta;
|
||||
|
||||
/// Maps thread ID to a coordinate offset within the tensor's logical coordinate space
|
||||
@@ -531,7 +531,7 @@ struct TransposePitchLinearThreadMap {
|
||||
|
||||
static_assert(Iterations::kCount, "Number of iterations must be non-zero");
|
||||
|
||||
///< Delta betweeen accesses (units of elements, concept: PitchLinearShape)
|
||||
///< Delta between accesses (units of elements, concept: PitchLinearShape)
|
||||
using Delta =
|
||||
layout::PitchLinearShape<Detail::WarpThreadArrangement::kContiguous *
|
||||
kElementsPerAccess,
|
||||
@@ -613,7 +613,7 @@ struct TransposePitchLinearThreadMapSimt {
|
||||
/// Shape of access by each thread
|
||||
using ThreadAccessShape = typename ThreadMap::ThreadAccessShape;
|
||||
|
||||
///< Delta betweeen accesses (units of elements, concept: PitchLinearShape)
|
||||
///< Delta between accesses (units of elements, concept: PitchLinearShape)
|
||||
using Delta =
|
||||
layout::PitchLinearShape<ThreadMap::Delta::kStrided,
|
||||
ThreadMap::Delta::kContiguous>;
|
||||
@@ -716,7 +716,7 @@ struct PitchLinearWarpStripedThreadMap {
|
||||
static_assert(Iterations::kCount,
|
||||
"Number of iterations must be non-zero");
|
||||
|
||||
///< Delta betweeen accesses (units of elements, concept: PitchLinearShape)
|
||||
///< Delta between accesses (units of elements, concept: PitchLinearShape)
|
||||
using Delta = layout::PitchLinearShape<
|
||||
Detail::WarpThreadArrangement::kContiguous * kElementsPerAccess,
|
||||
Detail::WarpThreadArrangement::kStrided * Detail::WarpArrangement::kStrided
|
||||
@@ -897,7 +897,7 @@ struct TransposePitchLinearThreadMap2DThreadTile {
|
||||
/// Shape of access by each thread
|
||||
using ThreadAccessShape = typename ThreadMap::ThreadAccessShape;
|
||||
|
||||
///< Delta betweeen accesses (units of elements, concept: PitchLinearShape)
|
||||
///< Delta between accesses (units of elements, concept: PitchLinearShape)
|
||||
using Delta =
|
||||
layout::PitchLinearShape<ThreadMap::Delta::kStrided,
|
||||
ThreadMap::Delta::kContiguous>;
|
||||
|
||||
@@ -76,7 +76,7 @@ namespace threadblock {
|
||||
/// To be efficient, this assumes the iterator will be dereferenced and advanced at least once
|
||||
/// outside any looping structure to minimize integer arithmetic.
|
||||
///
|
||||
/// Acceses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing
|
||||
/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing
|
||||
/// the iterator.
|
||||
///
|
||||
///
|
||||
|
||||
@@ -419,7 +419,7 @@ class PredicatedTileAccessIterator<Shape_, Element_, layout::PitchLinear,
|
||||
/// Tracks thread's coordinate offset in the matrix for current tile.
|
||||
/// This is only used in the following cases:
|
||||
/// - when Gather is true, strided coordinate needed to access indices (contiguous offset is tracked via pointer_)
|
||||
/// - when Permute is true, both coordinates are neeeded as input into permutation function (pointer_ is fixed)
|
||||
/// - when Permute is true, both coordinates are needed as input into permutation function (pointer_ is fixed)
|
||||
TensorCoord coord_offset_;
|
||||
|
||||
private:
|
||||
|
||||
@@ -81,7 +81,7 @@ namespace threadblock {
|
||||
/// To be efficient, this assumes the iterator will be dereferenced and advanced at least once
|
||||
/// outside any looping structure to minimize integer arithmetic.
|
||||
///
|
||||
/// Acceses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing
|
||||
/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing
|
||||
/// the iterator.
|
||||
///
|
||||
///
|
||||
|
||||
@@ -76,10 +76,10 @@ namespace threadblock {
|
||||
/// accesses may be performed without updating internal predicates and are efficient in terms of
|
||||
/// live register state and pointer arithmetic instructions.
|
||||
///
|
||||
/// To be efficient, this assumes the iteraor will be dereferenced and advanced at least once
|
||||
/// To be efficient, this assumes the iterator will be dereferenced and advanced at least once
|
||||
/// outside any looping structure to minimize integer arithmetic.
|
||||
///
|
||||
/// Acceses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing
|
||||
/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing
|
||||
/// the iterator.
|
||||
///
|
||||
///
|
||||
@@ -181,7 +181,7 @@ class PredicatedTileIterator2dThreadTile<Shape_, Element_, layout::PitchLinear,
|
||||
static int const kElements = ThreadMap::kElementsPerAccess;
|
||||
};
|
||||
|
||||
/// Optinally this fragment can be 4x4 transposed
|
||||
/// Optionally this fragment can be 4x4 transposed
|
||||
using Transform = thread::Transpose< ThreadMap::Iterations::kCount * ThreadMap::ThreadAccessShape::kCount , layout::PitchLinearShape<4,4>, Element>;
|
||||
static bool const transpose = Transpose_;
|
||||
|
||||
|
||||
@@ -76,10 +76,10 @@ namespace threadblock {
|
||||
/// accesses may be performed without updating internal predicates and are efficient in terms of
|
||||
/// live register state and pointer arithmetic instructions.
|
||||
///
|
||||
/// To be efficient, this assumes the iteraor will be dereferenced and advanced at least once
|
||||
/// To be efficient, this assumes the iterator will be dereferenced and advanced at least once
|
||||
/// outside any looping structure to minimize integer arithmetic.
|
||||
///
|
||||
/// Acceses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing
|
||||
/// Accesses out of bounds are safe so long as `clear_mask()` is called prior to dereferencing
|
||||
/// the iterator.
|
||||
///
|
||||
///
|
||||
|
||||
@@ -34,7 +34,7 @@
|
||||
\brief This defines a "fragment" iterator for visiting the fragments of a warp vector
|
||||
that participate in one warp-level mma operation.
|
||||
|
||||
Typically, this is used to access the scale/bias fragement of a warp-level mma operation.
|
||||
Typically, this is used to access the scale/bias fragment of a warp-level mma operation.
|
||||
The scale/bias vector is then partitioned into smaller fragments that can be fed into
|
||||
next warp-level mma operation.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user