diff --git a/changelog.md b/CHANGELOG.md similarity index 77% rename from changelog.md rename to CHANGELOG.md index d9ff1d5dd..c0606491e 100644 --- a/changelog.md +++ b/CHANGELOG.md @@ -1,6 +1,22 @@ # NVIDIA CUTLASS Changelog -## [1.0.1](https://github.com/NVIDIA/cutlass/releases/tag/v1.0.1) (2018-06-11) + +## 1.1.0 (2018-09-19) + * Turing Features + * WMMA GEMM targeting TensorCores - INT8, INT4, 1-bit + * Batched Strided GEMM + * Threadblock rasterization strategies + * Improved performance for adverse problem sizes and data layouts + * Extended CUTLASS Core components + * Tensor views support arbitrary matrix and tensor layouts + * Zip iterators for structuring multiple data streams + * Enhanced CUTLASS utilities + * Reference code for tensor operations in host and device code + * Added HostMatrix<> for simplified matrix creation + * Examples + * Basic GEMM, tensor views, CUTLASS utilities, batched GEMM, WMMA GEMM + +## 1.0.1 (2018-06-11) * Intra-threadblock reduction added for small threadblock tile sizes * sgemm_64x128x16, sgemm_128x128x16, sgemm_128x64x16, sgemm_128x32x16, sgemm_64x64x16, sgemm_64x32x16 diff --git a/CMakeLists.txt b/CMakeLists.txt index 5a53fae55..fdd51ae88 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -55,11 +55,21 @@ endif() find_package(CUDA) find_package(Doxygen QUIET) +################################################################################################### +# +# Configure CMake variables +# +################################################################################################### + +find_library(CUBLAS_LIBRARY cublas HINTS + ${CUDA_TOOLKIT_ROOT_DIR}/lib64 + ${CUDA_TOOLKIT_ROOT_DIR}/lib/x64) + # By default we want to build in Release mode to ensure that we're getting best performance if (NOT (CMAKE_BUILD_TYPE OR CONFIGURATION_TYPES)) set(CMAKE_BUILD_TYPE Release CACHE STRING "Choose build level" FORCE) # We do support Debug or Release builds - set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" 
"Release") + set_property(CACHE CMAKE_BUILD_TYPE PROPERTY STRINGS "Debug" "RelWithDebInfo" "Release") endif() if(WIN32) @@ -68,27 +78,59 @@ if(WIN32) endif() if (WIN32) - # Enable more warnings and treat as errors - string(APPEND NVCC_FLAGS " -Xcompiler /W3 -Xcompiler /WX") + # Enable more warnings and treat as errors + string(APPEND NVCC_FLAGS " -Xcompiler /W3 -Xcompiler /WX") - # Disable excess x86 floating point precision that can lead to results being labeled incorrectly - string(APPEND NVCC_FLAGS " -Xcompiler /fp:strict") + # Disable warning on Unicode characters + string(APPEND NVCC_FLAGS " -Xcompiler /wd4819") - # Verbose option - if (${CUTLASS_NVCC_VERBOSE}) - string(APPEND NVCC_FLAGS " -v") - endif() + # Disable excess x86 floating point precision that can lead to results being labeled incorrectly + string(APPEND NVCC_FLAGS " -Xcompiler /fp:strict") + + # Verbose option + if (${CUTLASS_NVCC_VERBOSE}) + string(APPEND NVCC_FLAGS " -v") + endif() endif(WIN32) -# Configure CUDA options -set(CUTLASS_NVCC_ARCHS "50;60;61;70" CACHE STRING "The SM architectures to build code for.") -set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.") +set(CUTLASS_NVCC_ARCHS "50;60;61;70;75" CACHE STRING "The SM architectures to build code for.") +set(CUTLASS_NVCC_EMBED_CUBIN ON CACHE BOOL "Embed compiled CUDA kernel binaries into executables.") +set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.") +set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.") +# +# NOTE: running with asan and CUDA requires the following environment variable: +# +# ASAN_OPTIONS=protect_shadow_gap=0:replace_intrin=0:detect_leaks=0 +# +# without the above environment setting, an error like the following may be generated: +# +# *** Error: Could not detect active GPU device ID [out of memory] +# ... +# ==9149==ERROR: LeakSanitizer: detected memory leaks +# ... 
+# +if(ENABLE_ASAN) # https://github.com/google/sanitizers/wiki/AddressSanitizer + string(APPEND NVCC_FLAGS " --compiler-options -fsanitize=address --compiler-options -fno-omit-frame-pointer") + string(APPEND CMAKE_EXE_LINKER_FLAGS " -fsanitize=address") +endif() + +################################################################################################### +# +# Configure CUDA build options +# +################################################################################################### + +# Set NVCC arguments foreach(ARCH ${CUTLASS_NVCC_ARCHS}) - string(APPEND NVCC_FLAGS " -gencode arch=compute_${ARCH},code=sm_${ARCH}") + if(CUTLASS_NVCC_EMBED_CUBIN) + string(APPEND NVCC_FLAGS " -gencode arch=compute_${ARCH},code=sm_${ARCH}") + endif() + if(CUTLASS_NVCC_EMBED_PTX) + string(APPEND NVCC_FLAGS " -gencode arch=compute_${ARCH},code=compute_${ARCH}") + endif() endforeach() - if (CUTLASS_NVCC_KEEP) string(APPEND NVCC_FLAGS " -keep") endif() @@ -99,11 +141,8 @@ else() string(APPEND NVCC_FLAGS " -lineinfo") endif() -if (UNIX) - string(APPEND NVCC_FLAGS " -Xcompiler -Wconversion") -endif() - string(APPEND NVCC_FLAGS_DEBUG " -g") +string(APPEND NVCC_FLAGS_RELWITHDEBINFO " -O3") string(APPEND NVCC_FLAGS_RELEASE " -O3") # define NDEBUG for release mode to disable assertions @@ -111,11 +150,13 @@ string(APPEND NVCC_FLAGS_RELEASE " -DNDEBUG") if (CUTLASS_NATIVE_CUDA) set(CMAKE_CUDA_FLAGS "${NVCC_FLAGS}") - set(CMAKE_CUDA_FLAGS_DEBUG "${NVCC_FLAGS_DEBUG}") set(CMAKE_CUDA_FLAGS_RELEASE "${NVCC_FLAGS_RELEASE}") + set(CMAKE_CUDA_FLAGS_RELWITHDEBINFO "${NVCC_FLAGS_RELWITHDEBINFO}") + set(CMAKE_CUDA_FLAGS_DEBUG "${NVCC_FLAGS_DEBUG}") else() set(CUDA_NVCC_FLAGS ${NVCC_FLAGS}) set(CUDA_NVCC_FLAGS_DEBUG ${NVCC_FLAGS_DEBUG}) + set(CUDA_NVCC_FLAGS_RELWITHDEBINFO ${NVCC_FLAGS_RELWITHDEBINFO}) set(CUDA_NVCC_FLAGS_RELEASE ${NVCC_FLAGS_RELEASE}) endif() @@ -128,6 +169,11 @@ file(GLOB CUTLASS_GEMM RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/gemm/*.h) file(GLOB CUTLASS_UTIL 
RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/util/*.h) file(GLOB CUTLASS_DEVICE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/device/*.h) file(GLOB CUTLASS_CORE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} cutlass/*.h) +################################################################################################### +# +# Define build targets +# +################################################################################################### source_group("cutlass\\gemm" FILES ${CUTLASS_GEMM}) source_group("cutlass\\util" FILES ${CUTLASS_UTIL}) @@ -156,9 +202,9 @@ add_custom_target(cutlass_ide SOURCES if (DOXYGEN_FOUND) # DOT is available. Enable graph generation in the documentation if (DOXYGEN_DOT_EXECUTABLE) - set(CUTLASS_ENABLE_DOXYGEN_DOT ON CACHE BOOL "Use dot to generate graphs in the doxygen documentation.") + set(CUTLASS_ENABLE_DOXYGEN_DOT ON CACHE BOOL "Use dot to generate graphs in the doxygen documentation.") else() - set(CUTLASS_ENABLE_DOXYGEN_DOT OFF CACHE BOOL "Use dot to generate graphs in the doxygen documentation." FORCE) + set(CUTLASS_ENABLE_DOXYGEN_DOT OFF CACHE BOOL "Use dot to generate graphs in the doxygen documentation." FORCE) endif() if (CUTLASS_ENABLE_DOXYGEN_DOT) @@ -177,6 +223,5 @@ if (DOXYGEN_FOUND) ) endif() - -#add_subdirectory(examples/gemm) add_subdirectory(tools) +add_subdirectory(examples) diff --git a/CUTLASS.md b/CUTLASS.md new file mode 100644 index 000000000..7dea0f372 --- /dev/null +++ b/CUTLASS.md @@ -0,0 +1,311 @@ +![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") + +# CUTLASS + +This document is intended to accompany the CUTLASS source code, to describe the interaction between +CUTLASS core components, and to identify their role in implementing GEMM computations efficiently in CUDA. + +1. [Design Patterns](#S-design-patterns) +2. [General Matrix Multiply](#S-general-matrix-multiply) +3. [Core Components](#S-core-components) +4. [Utilities](#S-utilities) + +# 1. 
Design Patterns + +CUTLASS strives to achieve the highest performance possible on NVIDIA GPUs while also offering a +flexible composition that can be easily applied to solve new problems related to Deep Learning and +linear algebra. Though we intend to make CUTLASS as simple and straightforward as possible, given +a tradeoff between simplicity and performance, CUTLASS chooses performance. Consequently, several +design patterns are necessary to yield a composable structure while also satisfying these performance +objectives. This section is intended to provide more detail. + +* [Sequencing and Nesting](#S-patterns-sequencing-nesting) +* [Tiles and Iterators](#S-patterns-tiles-iterators) +* [Host-side Params](#S-patterns-host-side-params) +* [Composable Shared Memory](#S-patterns-composable-shared-memory) + +## Sequencing and Nesting of Collective Primitives + +CUTLASS embodies a design paradigm exemplified by the [CUB library](https://nvlabs.github.io/cub/) for expressing collective operations. Objects expose an interface for a problem that is then decomposed into concurrent subtasks executed by cooperating threadblocks, warps, and threads. For example, a grid-level object may be constructed with base pointers to the start of a GEMM operation, add a threadblock-dependent offset to partition the problem, and then compute a per-threadblock GEMM. This in turn performs some operations as a collection of cooperating threads, while it may partition other parts of the task into warp-level subtasks. + +## Tiles and Iterators + +Efficient dense linear algebra computations emphasize data movement to match the execution of mathematical operators to the flow of data. Consequently, CUTLASS defines a rich set of primitives for partitioning a tile of data among participating threads, warps, and threadblocks. CUTLASS applies the familiar iterator design pattern to provide an abstraction layer to (1.) access these tile objects and (2.) 
traverse a sequence of objects embedded in a higher level data structure. These subpartitions are typically defined by compile-time constants +specifying element type, size, and data layout. CUTLASS refers to subpartitions as _tiles_. + +_Iterators_ are familiar design patterns in C++ that provide an abstraction for accessing individual +elements in memory as well as traversing over a collection. GEMM kernels in CUTLASS depend on accessing +a sequence of tiles from global memory, from shared memory, and in registers. Consequently, _tile iterators_ +are prevalent throughout the CUTLASS implementation. + +The canonical CUTLASS tile iterator template is defined in [cutlass/tile_iterator.h](cutlass/tile_iterator.h). + +## Host-side Params structure + +Several CUTLASS template classes exhibit a pattern in which problem-specific internal state is known at kernel launch time and remains invariant throughout the execution of a kernel. For example, tile iterators compute several offsets based on the strides of the input tensor that is added to an internal pointer when loading the elements of a tile. These are computed from the tensor stride and never updated; the per-thread internal state consists only of the internal global memory pointer. + +CUTLASS can take advantage of this CUDA grid-invariant property by constructing the object in host code and passing a composed parameters structure to the kernel. This confers two benefits: (1.) invariant state is held in constant memory, and (2.) there is no overhead to compute the initial state by each thread. + +The design pattern in CUTLASS is for classes with nontrivial constructors to define `struct Params` as an inner class which contains grid-invariant state. These should define a constructor and an `initialize()` method. The `Params` structure should also include a data member corresponding to each data member in the parent class, so these too can be properly constructed in host code. 
The parent class should define a constructor which accepts `Params const &` as its first argument. + +For example, `cutlass::gemm::Gemm<>` should define `struct cutlass::gemm::Gemm::Params`. The latter should define data members for each data member in `cutlass::gemm::Gemm<>`. + + +## Composable shared memory allocation + +Shared memory requires explicit effort by the programmer to allocate and de-allocate. CUTLASS follows the paradigm introduced by [CUB](https://nvlabs.github.io/cub/) to define composed structures for storing data intended to be held in shared memory. Any object requiring shared memory storage for itself or its data members should define a child structure called SharedStorage. This holds data needed by the class and also instantiates SharedStorage objects for each data member. + +To be consistent, this pattern defines a convention in which classes define internal shared memory storage requirements. Classes should consider all SharedStorage structures to be opaque other than their own child class. When the lifetimes of child objects are known to be non-overlapping, unions may be used to alias multiple SharedStorage objects to the same shared memory region and reduce overall SMEM capacity. + +## Loop Unrolling + +CUTLASS requires tiles of data to be stored in registers for high-bandwidth access. Simultaneously, high-throughput math instructions +must be issued concurrently with memory instructions to hide latency with relatively few concurrent threads. These objectives are +achieved by unrolling loops whose iteration counts are known at compile time. + +Consequently, most loops within the CUTLASS GEMM implementation are specified by constant values and template arguments. The CUDA compiler +is able to unroll the loop bodies, map array elements to registers, and construct an efficient instruction schedule. + +## Templates + +CUDA C++ templates and modern generic programming techniques enable CUTLASS device code to span a large design space. 
+ +This design space includes: +* Mixed precision arithmetic and data storage +* Kernels specialized for layout and problem size +* Support for kernel fusion + +Moreover, templates provided a structured approach to collecting compile-time constants such as tile dimensions. These +must be template arguments to target static array allocation and take advantage of loop unrolling, constant folding, +and function inlining. + +# 2. General Matrix Multiply + +The following figure illustrates the hierarchical GEMM computation embodied by CUTLASS. Each stage depicts a nested level of tiling which corresponds to a layer of concurrency within the CUDA execution model and to a level within the memory hierarchy, becoming increasingly finer moving left to right. + +![ALT](/media/images/gemm-structural-components.png "CUTLASS GEMM Structural Components") + +## Threadblock-level GEMM + +The CUTLASS GEMM kernel partitions the _C_ matrix into a 2D tiling of threadblocks. +Each threadblock computes a matrix product whose outer dimensions _M_ and _N_ are compile-time constants. The +GEMM's _K_ dimension is partitioned into tiles and iterated over by the GEMM _mainloop_. The shape of the matrix +multiply operation performed by each iteration of the mainloop is referred to as _OutputTile_. + +The threadblock loads a sequence of tiles from global memory and stores this data to shared memory. The iterative +access and traversal of tiles in global memory are performed by a _TileLoadIterator_, and storing to a circular +buffer in shared memory is performed by a _GlobalLoadIterator_. + +**[Global Load Stream](cutlass/gemm/gemm_global_stream.h)** manages loading of the threadblock-scope multiplicands to the GEMM kernel. It owns an iterator into global memory for loading tiles of data, a TensorAllocation in shared memory to hold the resulting tile, and an iterator for writing the tile into this allocation. 
A transformer exists to optionally transform the data as it is loaded which may be of use to perform type conversion or, in the case of int8 GEMM, transpose 4x4 tiles held in registers. + +The Global Load Stream template contains members defined by the following templates: + +* [GemmGlobalIteratorAb](cutlass/gemm/gemm_global_tile.h) +* [Transformer](cutlass/convert.h) +* [GemmSharedStoreTileAb](cutlass/gemm/gemm_shared_tile.h) + +## Warp-level GEMM + +The threadblock's _OutputTile_ is partitioned among the warps, and each computes a warp-level matrix product. +Data is loaded from shared memory into registers, and math instructions are dispatched to CUDA Cores or Tensor Cores. + +[**Shared Load Stream**](cutlass/gemm/gemm_shared_stream.h) manages loading of warp-level multiplicands from shared memory into registers. This owns an iterator for fetching data and the destination fragments for holding the results. + +* [GemmSharedLoadTile{A,B}](cutlass/gemm/gemm_shared_tile.h) + +**Matrix Multiply** computes a matrix product operation on data held in registers. Specializations exist for thread-level instructions such as single-precision fused multiply-add as well as warp-level matrix operations targeting TensorCores. + +* [WMMA Multiply Add](cutlass/gemm/wmma_gemm_multiply_add.h) + +## Thread-level GEMM + +SGEMM, IGEMM, HGEMM, and DGEMM are computed by SIMT math instructions issued by thread-level matrix multiply +procedures. + +* [ThreadMultiplyAdd](cutlass/gemm/thread_multiply_add.h) +* [IGEMM specialization](cutlass/gemm/igemm_multiply_add.h) +* [HGEMM specialization](cutlass/gemm/hgemm_multiply_add.h) + +## Epilogue + +The [**epilogue**](cutlass/gemm/gemm_epilogue.h) iteratively selects a subset of accumulator elements held by a warp, writes them to shared memory, and loads them by different threads such that a threadblock-scoped tile store operation will make contiguous, striped accesses to global memory. Thus, the flow of data utilizes the following components: + +1. 
[Transformer](cutlass/convert.h) for converting the data types of accumulator elements +2. [GemmSharedStoreTileD](cutlass/gemm/gemm_shared_tile.h) to store to shared memory specialized to the accumulator layout. +3. [GemmSharedLoadTileD](cutlass/gemm/gemm_shared_tile.h) to load the data from shared memory. +4. [GemmGlobalIteratorC](cutlass/gemm/gemm_global_tile.h) to load a tile from global memory. +5. A [functor](cutlass/gemm/linear_scaling.h) to compute an element-wise operation on the matrix product and source data (such as alpha*AB+beta*C). +6. [GemmGlobalIteratorD](cutlass/gemm/gemm_global_tile.h) to write the output to global memory. + +## GEMM Traits + +[**cutlass::gemm::GemmTraits**](cutlass/gemm/gemm_traits.h) collects the structural properties of a complete GEMM computation into a single template class. As a result, the Traits classes encapsulate the iterators and transformers for all supported GEMM operands and layouts. Low-level details needed by Traits (such as scalar types for operands, thread-block tile size, number of scalar elements per memory access within each phase, number of stages in shared memory, as well as other implementation-specific properties of the GEMM computation) are specified in class [**cutlass::gemm::GemmConfig**](cutlass/gemm/gemm_config.h). + + +# 3. Core Components + +CUTLASS GEMM kernels are implemented by a set of Core components for interacting with mathematical tensor and matrix +objects as well as constructing efficient CUDA kernels. + +* [Tensor views](#S-core-tensor-views) +* [Shape](#S-core-shape) +* [Tile structure](#S-core-tile-structure) +* [Fragment](#S-core-fragment) +* [Predicate vector](#S-core-predicate-vector) + +## Tensor View + +Matrices and tensors are typically represented as n-D arrays held in linear memory with a single base pointer and a stride vector. Element _i_ of the stride vector indicates the offset in linear memory between consecutive elements in dimension i. 
Consequently, the linear offset for an arbitrary element specified as an n-tuple may be computed as the dot product of the coordinate and the stride vector. + +CUTLASS provides abstractions for interacting with multidimensional tensors in device memory. +Consequently, we define a hierarchy of pointer-like types for referencing tensors. + +`T *` - raw pointer to elements of type T + +`cutlass::TensorRef` - reference to a tensor of elements of type T and given rank. Includes a mapping function and associated stride vector for accessing elements in linear memory. + +`cutlass::TensorView` - extends `TensorRef<>` by adding bounds information. This is a complete mathematical object which may be used as the argument to CUTLASS functions. + +The above provide an identity mapping of a logical index space to linear memory. An element +at logical coordinate X has an offset computed as follows: +``` +offset = dot(X, stride) +``` +where `dot()` computes the inner product of X and a vector of "strides." + +CUTLASS 1.1 introduces a mapping function and an additional "storage rank" to offer a flexible way to +map the logical index space of the tensor to memory. The mapping function maps a coordinate +of rank _R_ to an index space of rank _S_. The linear offset is computed as: +``` +offset = dot( MapFunc(X), stride ) +``` +where stride is a vector of rank _S_. + +CUTLASS kernels make extensive use of vectorization of memory accesses for efficiency and +correctness. Consequently, we enforce a constraint on the strides used by mapping functions +such that: + +1. The "fastest-changing" stride is always 1 thereby mandating that consecutive elements in + that rank are consecutive in linear memory. + +2. The fastest changing rank is always last in the stride vector and not explicitly stored. + +Thus, the stride vector used by mapping functions has length of one fewer than the rank of the +storage tensor. 
These constraints are consistent with the BLAS interface of passing matrices as +a tuple consisting of a pointer and a "leading dimension." In fact, these are rank=2 tensors +whose fastest changing dimension is 1, and only the strided dimension is explicitly represented. + +A typical mapping function might simply map the rows and columns of a matrix, a rank=2 tensor, +to linear memory such that (1.) elements in the same column are consecutive in memory +(column-major), or (2.) elements in the same row are consecutive (row-major). These can be +accomplished by two different mapping functions whose stride vector is length=2. The first +element is the "leading dimension." + +The requirement that the fastest-changing stride always be of unit size need not be a limitation. +To implement "sparse" computations or matrix operations in which matrix elements have arbitrary +stride along both row and column, define a mapping function whose storage rank is 3. This permits +two elements of the stride vector to have a non-unit value. + +`cutlass::TensorView<>` extends this concept by including a size vector to specify the bounds of +the index space. The value of each coordinate in the size vector defines the half-open range of +indices whose smallest value is zero. + +## Shape + +To avoid complicated template metaprogramming, CUTLASS targets fixed compile-time tile sizes specified +by a four-dimensional template `cutlass::Shape<>`. This defines the following dimensions, mirroring +the NHWC tensor format used for convolution in Deep Learning frameworks. + +- `D`: depth of tensor +- `H`: first strided dimension +- `W`: contiguous sequence of tensor elements +- `C`: number of channels, usually used for vectorized access + +Template specializations of `Shape` appear as arguments to numerous dependent template classes which +must specify compile-time constant tile sizes. 
+ +## Tile Structure + +Tiled structures express an arrangement of data in memory as well as a logical mapping of concurrent CUDA +threads to the problem space. For example, the CUTLASS GEMM + +Tiled structures can be defined using the `cutlass::TileTraits<>` concept which defines the following +members. Collectively, these members offer a flexible way to define a 4-D subpartition of an integer +lattice, partition its elements among a collection of threads, and map each unique thread ID to a unique +offset. + +- _Tile_ (concept `Shape<>`) - describes the dimensions of the tile in terms of scalar elements +- _Delta_ (concept `Shape<>`) - describes the distance along each logical dimension between items +- _Iterations_ (concept `Shape<>`) - describes the number of items along each logical dimension +- _ThreadOffset_ (concept _functor_) - implements `Coord<4> operator()() const` to determine a thread's + initial offset in the logical 4-D coordinate space + +The following figure illustrates the CUTLASS tile structure. The overall shape, 16-by-16, is partitioned into +vectors of length two among 32 threads. The elements stored by thread 9 are highlighted. 
+ +CUTLASS tile structure + +The `cutlass::TileTraits<>` definition that describes this arrangement may be defined as follows: + +``` +struct ExampleTileTraits { + + /// Overall shape of tile + typedef Shape<1, 16, 16, 1> Tile; + + /// Distance along each dimension of accesses + typedef Shape<1, 4, 1, 1> Delta; + + /// Number of memory accesses performed by each thread + typedef Shape<1, 4, 1, 1> Iterations; + + /// Offset function - maps each thread to a unique starting offset within the 4D tile + struct ThreadOffset { + + CUTLASS_DEVICE Coord<4> operator()() const { + + typedef Shape<1, 16, 8, 2> Vectorized; + + return make_Coord( + 0, // depth "D" dimension + threadIdx.x / Vectorized::kW, // horizontal "H" dimension - first strided dimension + threadIdx.x % Vectorized::kW, // vertical "W" dimension - contiguous dimension + 0 + ); + } + }; +}; +``` + +## Tile Iterator + +The iterator design pattern provides an abstraction for accessing the items in a collection in sequence. Basic +operators defined by iterators consist of accessing an item - either a load or store - followed by traversal to +the next item in sequence. + +CUTLASS tile access and traversal + +To offer a generic solution that spans numerous data types and layouts, CUTLASS defines the _TileIterator_ concept. +This concept provides access to a sequence of _tiles_ embedded in a tensor in addressable memory. + +The canonical CUTLASS tile iterator template is defined in [cutlass/tile_iterator.h](cutlass/tile_iterator.h). + +## Fragment + +A fragment is analogous to `std::array<>` in that it is a constant-sized array of elements. Typically backed by storage in the SM's register file, CUTLASS `Fragment<>` objects are used to store tiles. For threadblock- and warp-scope operations, the contents of these tiles are distributed across the participating threads. In such cases, a thread's `Fragment<>` contains the part of the tile held by that thread. 
+ +## Predicate Vector + +SIMT architectures utilize predicated execution in place of control flow when conditional code sequences are fairly short, on the order of a few machine instructions. While CUDA C++ does not include constructs at the language level for predication, PTX makes this explicit, and compilation to SASS is assumed to aggressively utilize predication. Typical applications are to initialize a sequence of bits used to mask memory operations and use these bits as predicates guarding memory load and store instructions. + +CUTLASS provides `PredicateVector` defined in [cutlass/predicate_vector.h](cutlass/predicate_vector.h) to manage a statically-sized bit vector, store them into general purpose registers, and efficiently access them in sequence. By storing four predicates per byte in hardware registers, the CUDA compiler is able to issue specialized instructions to achieve very efficient unpacking. + + +# 4. Utilities + +CUTLASS implements efficient matrix multiply computations on GPUs. It is accompanied by an extensive utility +framework offering features such as: + +* [cutlass::half_t](tools/util/half.h) - a host-side half-precision type +* Components for allocating and initializing [host-side and device-side tensors](tools/util/host_tensor.h) usable by CUTLASS +* Reference implementations of [GEMM](tools/util/reference/host/gemm.h) and [element-wise operations](tools/util/reference/host/tensor_elementwise.h) diff --git a/Doxyfile b/Doxyfile index 51cec529b..1d96f3770 100644 --- a/Doxyfile +++ b/Doxyfile @@ -58,7 +58,7 @@ PROJECT_LOGO = # entered, it will be relative to the location where doxygen was started. If # left blank the current directory will be used. 
-OUTPUT_DIRECTORY = docs +OUTPUT_DIRECTORY = doxygen # If the CREATE_SUBDIRS tag is set to YES, then doxygen will create 4096 sub- # directories (in 2 levels) under the output directory of each output format and diff --git a/README.md b/README.md index 56473a286..c53a42f4b 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,10 @@ ![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") -# CUTLASS 1.0 +# CUTLASS 1.1 -_CUTLASS 1.0.1 - June 2018_ +_CUTLASS 1.1.0 - September 2018_ -CUTLASS 1.0 is a collection of CUDA C++ template abstractions for implementing +CUTLASS 1.1 is a collection of CUDA C++ template abstractions for implementing high-performance matrix-multiplication (GEMM) at all levels and scales within CUDA. It incorporates strategies for hierarchical decomposition and data movement similar to those used to implement cuBLAS. CUTLASS decomposes these "moving parts" into @@ -22,14 +22,27 @@ point (FP64) types. Furthermore, CUTLASS demonstrates CUDA's WMMA API for targe the programmable, high-throughput _Tensor Cores_ provided by NVIDIA's Volta architecture and beyond. -CUTLASS 1.0 has changed substantially from our preview release described in -the [CUTLASS Parallel For All](https://devblogs.nvidia.com/parallelforall/cutlass-linear-algebra-cuda) -post. We have decomposed the structure of the GEMM computation into deeper, structured -primitives for loading data, computing predicate masks, streaming data at each level of -the GEMM hierarchy, and updating the output matrix. +CUTLASS 1.1 is described in the [CUTLASS Documentation](CUTLASS.md) and the accompanying +[Doxygen documentation](https://nvidia.github.io/cutlass). +We describe the structure of an efficient GEMM in our talk at the +[GPU Technology Conference 2018](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf). 
-CUTLASS 1.0 is described in the [Doxygen documentation](https://nvidia.github.io/cutlass) -and our talk at the [GPU Technology Conference 2018](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf). +# What's New in CUTLASS 1.1 + +* [CUTLASS Documentation](CUTLASS.md) +* [Examples](examples/) + * Basic GEMM, tensor views, CUTLASS utilities, batched GEMM, WMMA GEMM +* Turing Features + * [WMMA GEMM targeting TensorCores](tools/test/unit/gemm/wmma_integer_gemm.cu) - INT8, INT4, 1-bit +* [Batched Strided GEMM](tools/test/unit/gemm/batched_strided_sgemm_128x128x8.cu) +* [Threadblock rasterization strategies](tools/test/unit/gemm/sgemm_threadblock_swizzle_nt.cu) + * Improved performance for adverse problem sizes and data layouts +* Extended CUTLASS Core components + * Tensor views support arbitrary matrix and tensor layouts + * Zip iterators for structuring multiple data streams +* Enhanced CUTLASS utilities + * [Reference implementations](tools/util/reference) for tensor operations in [host](tools/util/reference/host) and [device](tools/util/reference/device) code + * Added `HostMatrix<>` for simplified matrix creation # Performance @@ -39,11 +52,11 @@ CUTLASS primitives are very efficient. When used to construct device-wide GEMM they exhibit performance comparable to cuBLAS for scalar GEMM computations. The above figure shows CUTLASS performance relative to cuBLAS for large matrix dimensions (M=10240, N=K=4096) running on an NVIDIA Titan V GPU -when compiled with CUDA 9.2. +when compiled with CUDA 10.0. # Compatibility -CUTLASS requires CUDA 9 and performs best with [CUDA 9.2 Toolkit](ttps://developer.nvidia.com/cuda-toolkit) or later. +CUTLASS requires CUDA 9 but performs best with [CUDA 10.0 Toolkit](https://developer.nvidia.com/cuda-toolkit) or later. 
|**Operating System** | **Compiler** | |-----------------|----------| @@ -63,7 +76,7 @@ any Maxwell-, Pascal-, or Volta-architecture NVIDIA GPU. |NVIDIA Tesla P100| |NVIDIA Tesla V100| |NVIDIA TitanV| - +|NVIDIA GeForce RTX 2080 TI, 2080, 2070| # Building CUTLASS @@ -79,7 +92,7 @@ $ git submodule update --init --recursive ``` CUTLASS can be build with CMake starting version 3.10. By default CUTLASS will build kernels -for CUDA architecture versions 5.0, 6.0, 6.1 and 7.0. To reduce compile time you can specify +for CUDA architecture versions 5.0, 6.0, 6.1, 7.0 and 7.5. To reduce compile time you can specify the architectures to build CUTLASS for by changing the CMake configuration setting `CUTLASS_NVCC_ARCHS`. @@ -107,13 +120,12 @@ $ ./tools/test/unit/cutlass_unit_test ... ... [----------] Global test environment tear-down -[==========] 481 tests from 24 test cases ran. (5954 ms total) -[ PASSED ] 481 tests. +[==========] 946 tests from 57 test cases ran. (10812 ms total) +[ PASSED ] 946 tests. ``` All tests should pass, though the exact number of tests may vary over time. - # Project Structure CUTLASS is arranged as a header-only library with several example test programs @@ -128,28 +140,41 @@ templates in the cutlass/gemm directory. ``` cutlass/ - gemm/ - util/ - + gemm/ + util/ + ``` Several tools and test programs are also distributed with the CUTLASS library. They are contained in the following directories. ``` +examples/ + 00_basic_gemm/ + 01_tensor_view/ + 02_cutlass_utilities/ + 03_batched_gemm/ + 04_tile_iterator/ + 05_wmma_gemm/ tools/ - test/ - unit/ - core/ - gemm/ - perf/ - util/ - + test/ + unit/ + core/ + gemm/ + perf/ + util/ + reference/ + device/ + host/ + ``` The `test/unit/` directory consist of unit tests implemented with Google Test that demonstrate basic usage of Core API components and complete tests of the CUTLASS GEMM computations. 
+The `tools/util` directory contains CUTLASS utilities including reference implementations of GEMM and +several element-wise tensor operations. + # Performance Profiling The `test/perf/` directory contains a command-line utility for launching each of the GEMM kernels. diff --git a/clang-format.sh b/clang-format.sh deleted file mode 100755 index b2570d914..000000000 --- a/clang-format.sh +++ /dev/null @@ -1,17 +0,0 @@ -#!/bin/bash - -set -e - -function formatFiles { - for f in `find "$1" -type f -name "*.$2"` ; do - COMMAND="clang-format -i $f" - echo $COMMAND - $COMMAND - done -} - -formatFiles "cutlass" "h" -formatFiles "tools/test" "h" -formatFiles "tools/test" "cpp" -formatFiles "tools/util" "h" - diff --git a/cutlass/convert.h b/cutlass/convert.h index 933d68a82..b4d0f8edd 100644 --- a/cutlass/convert.h +++ b/cutlass/convert.h @@ -28,7 +28,7 @@ */ #pragma once -#include +#include "cutlass/fragment.h" namespace cutlass { diff --git a/cutlass/coord.h b/cutlass/coord.h index 431c9bf1a..625a22723 100644 --- a/cutlass/coord.h +++ b/cutlass/coord.h @@ -28,7 +28,8 @@ #pragma once -#include +#include "cutlass/cutlass.h" +#include "cutlass/util/platform.h" namespace cutlass { @@ -44,20 +45,27 @@ struct Identity { //////////////////////////////////////////////////////////////////////////////////////////////////// /// Statically-sized array specifying Coords within a tensor -template +template struct Coord { // // Type and constant definitions // - static int const N = N_; + /// Number of elements in Coord + static int const kRank = Rank_; + + /// Number of elements in Coord, aliased for compatibility + static int const N = Rank_; + + /// Index type used to store elements + typedef Index_ Index; // // Data members // /// Indices - int idx[N]; + Index idx[kRank]; // // Methods @@ -65,25 +73,72 @@ struct Coord { /// Default ctor initializes uniformly CUTLASS_HOST_DEVICE - Coord(int value = 0) { - for (int i = 0; i < N; ++i) { + Coord(Index value = 0) { + for (int i = 0; i < 
kRank; ++i) { idx[i] = value; } } /// Constructs from an array of integers CUTLASS_HOST_DEVICE - Coord(int _idx[]) { - for (int i = 0; i < N; ++i) { + Coord(Index _idx[]) { + for (int i = 0; i < kRank; ++i) { idx[i] = _idx[i]; } } + /// Constructs from an array of integers + CUTLASS_HOST_DEVICE + Coord(Coord const &coord) { + for (int i = 0; i < kRank; ++i) { + idx[i] = coord[i]; + } + } + + /// Returns a slice of the Coord which may be larger or smaller in rank + /// than this. + template + CUTLASS_HOST_DEVICE + Coord slice(int start = 0, Index identity = 0) const { + Coord result; + for (int i = 0; i < Slice; ++i) { + if (i + start < kRank) { + result[i] = idx[i + start]; + } + else { + result[i] = identity; + } + } + return result; + } + + /// Returns true if Coord is non-zero. + CUTLASS_HOST_DEVICE + operator bool() const { + for (int i = 0; i < kRank; ++i) { + if (idx[i]) { + return true; + } + } + return false; + } + + /// Returns true if Coord is uniformly zero. + CUTLASS_HOST_DEVICE + bool operator!() const { + for (int i = 0; i < kRank; ++i) { + if (idx[i]) { + return false; + } + } + return true; + } + /// Element-wise addition CUTLASS_HOST_DEVICE Coord operator+(Coord const& b) const { Coord c; - for (int i = 0; i < N; ++i) { + for (int i = 0; i < kRank; ++i) { c.idx[i] = idx[i] + b.idx[i]; } return c; @@ -93,7 +148,7 @@ struct Coord { CUTLASS_HOST_DEVICE Coord operator-(Coord const& b) const { Coord c; - for (int i = 0; i < N; ++i) { + for (int i = 0; i < kRank; ++i) { c.idx[i] = idx[i] - b.idx[i]; } return c; @@ -103,7 +158,7 @@ struct Coord { CUTLASS_HOST_DEVICE Coord operator*(Coord const& b) const { Coord c; - for (int i = 0; i < N; ++i) { + for (int i = 0; i < kRank; ++i) { c.idx[i] = idx[i] * b.idx[i]; } return c; @@ -113,7 +168,7 @@ struct Coord { CUTLASS_HOST_DEVICE Coord operator/(Coord const& b) const { Coord c; - for (int i = 0; i < N; ++i) { + for (int i = 0; i < kRank; ++i) { c.idx[i] = idx[i] / b.idx[i]; } return c; @@ -122,7 +177,7 @@ 
struct Coord { /// In-place addition CUTLASS_HOST_DEVICE Coord& operator+=(Coord const& b) { - for (int i = 0; i < N; ++i) { + for (int i = 0; i < kRank; ++i) { idx[i] += b.idx[i]; } return *this; @@ -131,7 +186,7 @@ struct Coord { /// In-place subtraction CUTLASS_HOST_DEVICE Coord& operator-=(Coord const& b) { - for (int i = 0; i < N; ++i) { + for (int i = 0; i < kRank; ++i) { idx[i] -= b.idx[i]; } return *this; @@ -140,7 +195,7 @@ struct Coord { /// In-place multiplication CUTLASS_HOST_DEVICE Coord& operator*=(Coord const& b) { - for (int i = 0; i < N; ++i) { + for (int i = 0; i < kRank; ++i) { idx[i] *= b.idx[i]; } return *this; @@ -149,22 +204,22 @@ struct Coord { /// In-place division CUTLASS_HOST_DEVICE Coord& operator/=(Coord const& b) { - for (int i = 0; i < N; ++i) { + for (int i = 0; i < kRank; ++i) { idx[i] /= b.idx[i]; } return *this; } /// Member access operator - CUTLASS_HOST_DEVICE int& operator[](int dim) { return idx[dim]; } + CUTLASS_HOST_DEVICE Index& operator[](int dim) { return idx[dim]; } /// Member access operator - CUTLASS_HOST_DEVICE int const& operator[](int dim) const { return idx[dim]; } + CUTLASS_HOST_DEVICE Index const& operator[](int dim) const { return idx[dim]; } /// Computes the dot product of two Coord instances template CUTLASS_HOST_DEVICE T dot(Coord const& b, T sum) const { - for (int i = 0; i < N; ++i) { + for (int i = 0; i < kRank; ++i) { sum += idx[i] * b.idx[i]; } return sum; @@ -174,7 +229,7 @@ struct Coord { template CUTLASS_HOST_DEVICE T dot(Coord const& b) const { T sum = T(0); - for (int i = 0; i < N; ++i) { + for (int i = 0; i < kRank; ++i) { sum += idx[i] * b.idx[i]; } return sum; @@ -182,29 +237,29 @@ struct Coord { /// Gets the index of a given Coord element template - CUTLASS_HOST_DEVICE int& at() { + CUTLASS_HOST_DEVICE Index& at() { return idx[Dim]; } /// Access via index; may limit unrolling potential CUTLASS_HOST_DEVICE - int& at(int dim) { return idx[dim]; } + Index& at(int dim) { return idx[dim]; } /// Gets 
the index of a given Coord element template - CUTLASS_HOST_DEVICE int const& at() const { + CUTLASS_HOST_DEVICE Index const& at() const { return idx[Dim]; } /// Access via index; may limit unrolling potential CUTLASS_HOST_DEVICE - int const& at(int dim) const { return idx[dim]; } + Index const& at(int dim) const { return idx[dim]; } /// Determines if two Coord<> objects are equal CUTLASS_HOST_DEVICE - bool operator==(Coord const& b) const { + bool operator==(Coord const& b) const { bool equal = true; - for (int i = 0; equal && i < N; ++i) { + for (int i = 0; equal && i < kRank; ++i) { equal = (idx[i] == b.idx[i]); } return equal; @@ -212,12 +267,12 @@ struct Coord { /// Not equal CUTLASS_HOST_DEVICE - bool operator!=(Coord const& b) const { return !(*this == b); } + bool operator!=(Coord const& b) const { return !(*this == b); } /// Clamps a coordinate to a range specified by maximum and minimum values CUTLASS_HOST_DEVICE - Coord& clamp(Coord const& max, Coord const& min = Coord()) { - for (int i = 0; i < N; ++i) { + Coord& clamp(Coord const& max, Coord const& min = Coord()) { + for (int i = 0; i < kRank; ++i) { idx[i] = __NV_STD_MAX(__NV_STD_MIN(idx[i], max.idx[i]), min.idx[i]); } return *this; @@ -225,13 +280,35 @@ struct Coord { /// Returns the product of all elements CUTLASS_HOST_DEVICE - int count() const { - int product = idx[0]; - for (int i = 1; i < N; ++i) { + Index count() const { + Index product = idx[0]; + for (int i = 1; i < kRank; ++i) { product *= idx[i]; } return product; } + + /// Less than operator + CUTLASS_HOST_DEVICE + bool operator<(Coord const &b) const { + for (int i = 0; i < kRank; ++i) { + if (!(idx[i] < b[i])) { + return false; + } + } + return true; + } + + /// Less than or equals operator + CUTLASS_HOST_DEVICE + bool operator<=(Coord const &b) const { + for (int i = 0; i < kRank; ++i) { + if (!(idx[i] <= b[i])) { + return false; + } + } + return true; + } }; 
//////////////////////////////////////////////////////////////////////////////////////////////////// @@ -266,21 +343,10 @@ Coord<4> make_Coord(int _0, int _1, int _2, int _3) { //////////////////////////////////////////////////////////////////////////////////////////////////// -/// Getter -CUTLASS_HOST_DEVICE -Coord<2> get_Coord_hw(Coord<3> const& coord) { return make_Coord(coord[1], coord[2]); } - -/// Getter -CUTLASS_HOST_DEVICE -Coord<2> get_Coord_hw(Coord<4> const& coord) { return make_Coord(coord[1], coord[2]); } - -/// Getter -CUTLASS_HOST_DEVICE -Coord<3> get_Coord_hwc(Coord<4> const& coord) { return make_Coord(coord[1], coord[2], coord[3]); } - -/// Getter -CUTLASS_HOST_DEVICE -Coord<3> get_Coord_dhw(Coord<4> const& coord) { return make_Coord(coord[0], coord[1], coord[2]); } +template +CUTLASS_HOST_DEVICE Coord<3> make_Coord_from_shape() { + return make_Coord(Shape_::kD, Shape_::kH, Shape_::kW); +} //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cutlass/core_io.h b/cutlass/core_io.h index cceea4c06..849a7613f 100644 --- a/cutlass/core_io.h +++ b/cutlass/core_io.h @@ -22,8 +22,6 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ -#pragma once - /*! \file \brief Helpers for printing cutlass/core objects */ @@ -33,12 +31,96 @@ #include #include -#include +#include "cutlass/coord.h" +#include "cutlass/vector.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// template -std::ostream& operator<<(std::ostream& out, cutlass::Coord const& coord) { +std::ostream& operator<<(std::ostream& out, Coord const& coord) { for (int i = 0; i < Rank; ++i) { out << (i ? 
", " : "") << coord.idx[i]; } return out; } + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to enable formatted printing of CUTLASS scalar types to an ostream +template +struct ScalarIO { + + /// Value to print + T value; + + /// Default ctor + ScalarIO() { } + + /// Constructs from a value + ScalarIO(T value): value(value) {} +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Default printing to ostream +template +inline std::ostream &operator<<(std::ostream &out, ScalarIO const &scalar) { + return out << scalar.value; +} + +/// Printing to ostream of int8_t as integer rather than character +template <> +inline std::ostream &operator<<(std::ostream &out, ScalarIO const &scalar) { + return out << int(scalar.value); +} + +/// Printing to ostream of uint8_t as integer rather than character +template <> +inline std::ostream &operator<<(std::ostream &out, ScalarIO const &scalar) { + return out << unsigned(scalar.value); +} + +/// Printing to ostream of vector of 1b elements +template <> +inline std::ostream &operator<<( + std::ostream &out, + ScalarIO > const &scalar) { + + for (int i = 0; i < 32; i++) { + out << int(scalar.value[i]); + out << ((i != 31) ? ", " : ""); + } + return out; +} + +/// Printing to ostream of vector of 4b signed integer elements +template <> +inline std::ostream &operator<<( + std::ostream &out, + ScalarIO > const &scalar) { + + for (int i = 0; i < 8; i++) { + out << int(scalar.value[i]); + out << ((i != 7) ? ", " : ""); + } + return out; +} + +/// Printing to ostream of vector of 4b unsigned integer elements +template <> +inline std::ostream &operator<<( + std::ostream &out, + ScalarIO > const &scalar) { + + for (int i = 0; i < 8; i++) { + out << unsigned(scalar.value[i]); + out << ((i != 7) ? 
", " : ""); + } + return out; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/cutlass/cutlass.h b/cutlass/cutlass.h index 19600ec8f..15ea83c01 100644 --- a/cutlass/cutlass.h +++ b/cutlass/cutlass.h @@ -32,8 +32,8 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// #define CUTLASS_MAJOR 1 -#define CUTLASS_MINOR 0 -#define CUTLASS_PATCH 1 +#define CUTLASS_MINOR 1 +#define CUTLASS_PATCH 0 #define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH) #ifdef __NVCC__ @@ -47,7 +47,9 @@ // CUTLASS_DEVICE is an error if not compiling device code #endif -// CUTLASS_PRAGMA_UNROLL inserts a CUTLASS_PRAGMA_UNROLL if supported by the compiler +#define CUTLASS_ASSERT(x) assert(x) + +// CUTLASS_PRAGMA_(UNROLL|NO_UNROLL) optimization directives for the CUDA compiler. #if defined(__CUDA_ARCH__) #if defined(_MSC_VER) #define CUTLASS_PRAGMA_UNROLL __pragma("unroll") @@ -61,7 +63,22 @@ #define CUTLASS_PRAGMA_NO_UNROLL #endif -#define CUTLASS_ASSERT(x) assert(x) +#define CUTLASS_GEMM_LOOP CUTLASS_PRAGMA_NO_UNROLL + +// A small helper class to dump a type at compile time +// Usage:: DumpType::Class +template +struct DebugType {}; + +template +void DebugTypeFunc(T const& t) { + T::t; +} + +// A small helper class to dump a compile time constant at compile time +// Usage: DumpValue::kConstant +template +struct DebugValue {}; namespace cutlass { diff --git a/cutlass/fragment.h b/cutlass/fragment.h index 886b11405..6a93d779c 100644 --- a/cutlass/fragment.h +++ b/cutlass/fragment.h @@ -29,9 +29,9 @@ #pragma once #include -#include -#include -#include +#include "cutlass/shape.h" +#include "cutlass/util/cutlass_math.h" +#include "cutlass/vector.h" namespace cutlass { @@ -72,7 +72,7 @@ provides access to element at (d, h, w, c) //////////////////////////////////////////////////////////////////////////////////////////////////// 
-template +template struct StorageType { typedef uint64_t Type; }; @@ -108,9 +108,11 @@ struct Fragment : public AlignedStruct { typedef Element_ Element; /// The number of elements. static int const kElements = kElements_; + /// Alignment + static int const kAlignment = kAlignment_; /// Clear a fragment. - CUTLASS_DEVICE void clear() { + CUTLASS_HOST_DEVICE void clear() { // Avoid element-wise access for sub 32b element type if (kAlignment_ >= 8 && (kElements * sizeof(Element)) % 8 == 0) { uint64_t* ptr = reinterpret_cast(storage); @@ -135,14 +137,10 @@ struct Fragment : public AlignedStruct { } /// The accessor. - CUTLASS_DEVICE Element& operator[](int i) { - assert(i < kElements_); - return reinterpret_cast(storage)[i]; - } + CUTLASS_HOST_DEVICE Element& operator[](int i) { return reinterpret_cast(storage)[i]; } /// The accessor. - CUTLASS_DEVICE Element const& operator[](int i) const { - assert(i < kElements_); + CUTLASS_HOST_DEVICE Element const& operator[](int i) const { return reinterpret_cast(storage)[i]; } @@ -188,35 +186,35 @@ struct FragmentIterator { /// Ctor. template - CUTLASS_DEVICE FragmentIterator(OtherFragment_& fragment, int offset = 0) + CUTLASS_HOST_DEVICE FragmentIterator(OtherFragment_& fragment, int offset = 0) : pointer(reinterpret_cast(&fragment[offset])) { static_assert(OtherFragment_::kElements >= Fragment::kElements, ""); } /// The accessor. - CUTLASS_DEVICE AccessType const& at(int d, int h, int w, int c = 0) const { + CUTLASS_HOST_DEVICE AccessType const& at(int d, int h, int w, int c = 0) const { int const imm = ComputeOffsetFromStrides::get(d, h, w, c); return reinterpret_cast(pointer[imm]); } /// The accessor. - CUTLASS_DEVICE AccessType& at(int d, int h, int w, int c = 0) { + CUTLASS_HOST_DEVICE AccessType& at(int d, int h, int w, int c = 0) { int const imm = ComputeOffsetFromStrides::get(d, h, w, c); return reinterpret_cast(pointer[imm]); } /// The accessor. 
- CUTLASS_DEVICE AccessType const& operator[](int i) const { + CUTLASS_HOST_DEVICE AccessType const& operator[](int i) const { return reinterpret_cast(pointer[i * kElementsPerAccess]); } /// The accessor. - CUTLASS_DEVICE AccessType& operator[](int i) { + CUTLASS_HOST_DEVICE AccessType& operator[](int i) { return reinterpret_cast(pointer[i * kElementsPerAccess]); } /// Is the iterator valid? - CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const { return true; } + CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const { return true; } /// The pointer. Element* pointer; @@ -246,28 +244,28 @@ struct FragmentConstIterator { /// Ctor. template - CUTLASS_DEVICE FragmentConstIterator(OtherFragment_& fragment, int offset = 0) + CUTLASS_HOST_DEVICE FragmentConstIterator(OtherFragment_& fragment, int offset = 0) : pointer(reinterpret_cast(&fragment[offset])) { static_assert(OtherFragment_::kElements >= Fragment::kElements, ""); } /// Create from non-constant FragmentIterator - CUTLASS_DEVICE FragmentConstIterator( + CUTLASS_HOST_DEVICE FragmentConstIterator( FragmentIterator const& rhs_) : pointer(reinterpret_cast(rhs_.offset)) {} /// The accessor. - CUTLASS_DEVICE AccessType const& at(int d, int h, int w, int c = 0) const { + CUTLASS_HOST_DEVICE AccessType const& at(int d, int h, int w, int c = 0) const { int const imm = ComputeOffsetFromStrides::get(d, h, w, c); return reinterpret_cast(pointer[imm]); } /// The accessor. - CUTLASS_DEVICE AccessType const& operator[](int i) const { + CUTLASS_HOST_DEVICE AccessType const& operator[](int i) const { return reinterpret_cast(pointer[i * kElementsPerAccess]); } /// Is the iterator valid? - CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const { return true; } + CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const { return true; } /// The pointer. 
Element const* pointer; diff --git a/cutlass/fragment_load_store.h b/cutlass/fragment_load_store.h deleted file mode 100644 index a7d272e9e..000000000 --- a/cutlass/fragment_load_store.h +++ /dev/null @@ -1,135 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved. - * - * Redistribution and use in source and binary forms, with or without modification, are permitted - * provided that the following conditions are met: - * * Redistributions of source code must retain the above copyright notice, this list of - * conditions and the following disclaimer. - * * Redistributions in binary form must reproduce the above copyright notice, this list of - * conditions and the following disclaimer in the documentation and/or other materials - * provided with the distribution. - * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used - * to endorse or promote products derived from this software without specific prior written - * permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR - * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND - * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, - * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; - * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, - * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/*! 
\file - \brief Defines accessors for loading and storing fragments to memory efficiently. -*/ -#pragma once - -#include -#include - -namespace cutlass { - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct FragmentLoad {}; - -template -struct FragmentLoad { - /// The output type. - typedef FragmentElement_ AccessType; - - /// The load function. - static CUTLASS_DEVICE void load(AccessType& value, Scalar_ const* pointer, int offset) { - value.load(&pointer[offset], kStride); - } -}; - -template -struct FragmentLoad { - /// The output type. - typedef typename Vectorize::Type AccessType; - - /// The load function. - static CUTLASS_DEVICE void load(AccessType& value, Scalar_ const* pointer, int offset) { - Load::load(value, pointer, offset); - } -}; - -template -struct FragmentStore {}; - -template -struct FragmentStore { - /// The input type. - typedef FragmentElement_ AccessType; - - /// The store function. - static CUTLASS_DEVICE void store(AccessType const& value, Scalar_* pointer, int offset) { - value.store(&pointer[offset], kStride); - } -}; - -template -struct FragmentStore { - /// The input type. - typedef typename Vectorize::Type AccessType; - - /// The store function. 
- static CUTLASS_DEVICE void store(AccessType const& value, Scalar_* pointer, int offset) { - Store::store(value, pointer, offset); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} /// namespace cutlass diff --git a/cutlass/fragment_multiply_add.h b/cutlass/fragment_multiply_add.h index 36a4d6f6a..de2c8052f 100644 --- a/cutlass/fragment_multiply_add.h +++ b/cutlass/fragment_multiply_add.h @@ -27,52 +27,59 @@ */ #pragma once -#include +#include "cutlass/fragment.h" namespace cutlass { namespace gemm { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template < typename ScalarAlphaBeta_, + typename ScalarAccum_, + bool fragMul2 = true /*number of elements per fragment is a multiple of 2*/ +> struct FragmentMultiplyAdd { /// The shape of the instruction. typedef Shape<1, 1, 1, 1> InstructionShape; - /// The type for A. - typedef Scalar_ ScalarA; - /// The type for B. - typedef Scalar_ ScalarB; - /// The type for C and D. - typedef Scalar_ ScalarC; + /// The type for alpha and beta + typedef ScalarAlphaBeta_ ScalarAlphaBeta; + /// The type for accumulator + typedef ScalarAccum_ ScalarAccum; /// Ctor. CUTLASS_DEVICE FragmentMultiplyAdd() {} /// Multiply : d = a*b. template - CUTLASS_DEVICE void multiply(Scalar_ a, FragmentB_ const& b, FragmentCd_& d) { + CUTLASS_DEVICE void multiply(ScalarAlphaBeta a, FragmentB_ const& b, FragmentCd_& d) { +#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530 int const kReduction = FragmentB_::kElements / FragmentCd_::kElements; for (int j = 0; j < FragmentCd_::kElements; ++j) { - d[j] = a * b[j * kReduction + 0]; + d[j] = b[j * kReduction + 0]; for (int k = 1; k < kReduction; ++k) { - d[j] += a * b[j * kReduction + k]; + d[j] += b[j * kReduction + k]; } + d[j] = a * ScalarAlphaBeta(d[j]); } +#endif } /// Multiply : d = a*b + c. 
template - CUTLASS_DEVICE void multiply_add(Scalar_ a, + CUTLASS_DEVICE void multiply_add(ScalarAlphaBeta a, FragmentB_ const& b, FragmentCd_ const& c, FragmentCd_& d) { +#if defined(__CUDACC__) && __CUDA_ARCH__ >= 530 int const kReduction = FragmentB_::kElements / FragmentCd_::kElements; for (int j = 0; j < FragmentCd_::kElements; ++j) { - d[j] = a * b[j * kReduction + 0] + c[j]; + d[j] = b[j * kReduction + 0]; for (int k = 1; k < kReduction; ++k) { - d[j] += a * b[j * kReduction + k]; + d[j] += b[j * kReduction + k]; } + d[j] = a * ScalarAlphaBeta(d[j]) + ScalarAlphaBeta(c[j]); } +#endif } }; @@ -80,15 +87,13 @@ struct FragmentMultiplyAdd { #if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16) template <> -struct FragmentMultiplyAdd { +struct FragmentMultiplyAdd { /// The shape of the instruction. - typedef Shape<1, 1, 2, 1> InstructionShape; - /// The type for A. - typedef half ScalarA; - /// The type for B. - typedef half ScalarB; - /// The type for C and D. - typedef half ScalarC; + typedef Shape<1, 1, 1, 1> InstructionShape; + /// The type for alpha and beta + typedef half ScalarAlphaBeta; + /// The type for accumulator + typedef half ScalarAccum; /// Ctor. CUTLASS_DEVICE FragmentMultiplyAdd() {} @@ -97,17 +102,19 @@ struct FragmentMultiplyAdd { template CUTLASS_DEVICE void multiply(half a, FragmentB_ const& b, FragmentCd_& d) { #if defined(__CUDACC__) && __CUDA_ARCH__ >= 530 - - // Assemble a half2 from a. - __half2 const a_half2 = __half2half2(a); // The input. __half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]); // The output. __half2* d_half2 = reinterpret_cast<__half2*>(&d[0]); - int const kReduction = FragmentB_::kElements / FragmentCd_::kElements; + // Assemble a half2 from a. 
+ __half2 const a_half2 = __half2half2(a); + + int const kReduction = (FragmentB_::kElements / FragmentCd_::kElements); + for (int j = 0; j < FragmentCd_::kElements / 2; ++j) { d_half2[j] = __hmul2(a_half2, b_half2[j * kReduction + 0]); + for (int k = 1; k < kReduction; ++k) { d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + k], d_half2[j]); } @@ -115,6 +122,7 @@ struct FragmentMultiplyAdd { #endif } + /// Multiply : d = a*b + c. template CUTLASS_DEVICE void multiply_add(half a, @@ -122,17 +130,19 @@ struct FragmentMultiplyAdd { FragmentCd_ const& c, FragmentCd_& d) { #if defined(__CUDACC__) && __CUDA_ARCH__ >= 530 - // Assemble a half2 from a. - __half2 const a_half2 = __half2half2(a); // The inputs. __half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]); __half2 const* c_half2 = reinterpret_cast<__half2 const*>(&c[0]); // The output. __half2* d_half2 = reinterpret_cast<__half2*>(&d[0]); + // Assemble a half2 from a. + __half2 const a_half2 = __half2half2(a); + int const kReduction = (FragmentB_::kElements / FragmentCd_::kElements); for (int j = 0; j < FragmentCd_::kElements / 2; ++j) { d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + 0], c_half2[j]); + for (int k = 1; k < kReduction; ++k) { d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + k], d_half2[j]); } diff --git a/cutlass/gemm/clear_accumulators.h b/cutlass/gemm/clear_accumulators.h index 441370f4c..3a2f33752 100644 --- a/cutlass/gemm/clear_accumulators.h +++ b/cutlass/gemm/clear_accumulators.h @@ -27,7 +27,7 @@ */ #pragma once -#include +#include "cutlass/vector.h" namespace cutlass { namespace gemm { @@ -39,11 +39,12 @@ struct ClearAccumulators { /// The shared storage. struct SharedStorage {}; - /// Ctor. - CUTLASS_DEVICE ClearAccumulators() {} /// Ctor. CUTLASS_DEVICE ClearAccumulators(SharedStorage& shared_storage) {} + /// Ctor. + CUTLASS_DEVICE ClearAccumulators() {} + /// Clear the fragment. 
template CUTLASS_DEVICE void clear(Fragment_& fragment) { diff --git a/cutlass/gemm/dgemm_traits.h b/cutlass/gemm/dgemm_traits.h index 0bbc2210b..5c0559020 100644 --- a/cutlass/gemm/dgemm_traits.h +++ b/cutlass/gemm/dgemm_traits.h @@ -27,13 +27,13 @@ */ #pragma once -#include -#include -#include -#include -#include -#include -#include +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/gemm_epilogue.h" +#include "cutlass/gemm/gemm_epilogue_traits.h" +#include "cutlass/gemm/gemm_global_tile.h" +#include "cutlass/gemm/gemm_shared_tile.h" +#include "cutlass/gemm/gemm_traits.h" +#include "cutlass/gemm/thread_multiply_add.h" namespace cutlass { namespace gemm { @@ -41,10 +41,10 @@ namespace gemm { //////////////////////////////////////////////////////////////////////////////////////////////////// template < - /// The tile size for the GEMM KxNxM. + /// The tile size for threadblock-level GEMM (K-by-N-by-M). typename OutputTile_, - /// The number of accumulators per thread. - typename AccumulatorsPerThread_, + /// Tile size for thread-level GEMM (K-by-N-by-M) + typename ThreadGemmShape_, /// The number of scalars per LDG for A. int kScalarsPerLdgA_ = 1, /// The number of scalars per LDG for B. @@ -62,7 +62,7 @@ struct DgemmConfig /// The tile size for the GEMM KxNxM. OutputTile_, /// The functor to do the math in the main loop. - ThreadMultiplyAdd, double, double, double>, + ThreadMultiplyAdd, double, double, double>, /// The number of scalars per LDG for A. kScalarsPerLdgA_, /// The number of scalars per STS for A. @@ -82,7 +82,14 @@ struct DgemmConfig /// The number of scalars per LDS for D. 1, /// The number of stages in shared memory. - 2> {}; + 2, + /// kResidueSeparate + false, + /// kResidueInPrologue + false, + /// kLaunchBounds + false + >{}; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -91,12 +98,12 @@ template < MatrixLayout::Kind kLayoutA_, /// The layout for B. 
MatrixLayout::Kind kLayoutB_, - /// The output tile. + /// The tile size for threadblock-level GEMM (K-by-N-by-M) typename OutputTile_ = Shape<8, 64, 128>, /// The functor to use in the epilogue. typename EpilogueFunctor_ = LinearScaling, - /// The number of accumulators per thread. - typename AccumulatorsPerThread_ = Shape<8, 8, 8>, + /// Tile size for thread-level GEMM (K-by-N-by-M) + typename ThreadGemmShape_ = Shape<8, 8, 8>, /// The number of doubles loaded in one LDG for A. int kScalarsPerLdgA_ = 1, /// The number of doubles loaded in one LDG for B. @@ -105,7 +112,7 @@ template < typename Index_ = int, /// The DGEMM config. typename GemmConfig_ = - DgemmConfig, + DgemmConfig, /// The traits class for the epilogue. typename GemmEpilogueTraits_ = SimplifiedGemmEpilogueTraits > diff --git a/cutlass/gemm/fp16_sgemm_multiply_add.h b/cutlass/gemm/fp16_sgemm_multiply_add.h new file mode 100644 index 000000000..534b8c899 --- /dev/null +++ b/cutlass/gemm/fp16_sgemm_multiply_add.h @@ -0,0 +1,83 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. 
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Template implementing matrix multiply-add operations on fragments.
+*/
+#pragma once
+
+#include "cutlass/fragment.h"
+#include "cutlass/gemm/thread_multiply_add.h"
+namespace cutlass {
+namespace gemm {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Template performing matrix multiply-add operation within a thread
+template
+struct ThreadMultiplyAdd {
+ /// The shape of the instruction.
+ typedef Shape<1, 1, 1, 1> InstructionShape;
+ /// The shape of a thread-level matrix multiply accumulate.
+ typedef ThreadGemmShape_ ThreadGemmShape;
+ /// Aliased to "AccumulatorsPerThread" for compatibility. Expect to be renamed in CUTLASS v2.0
+ typedef ThreadGemmShape AccumulatorsPerThread;
+ /// The number of threads per warp.
+ typedef ThreadsPerWarp_ ThreadsPerWarp;
+ /// The number of accumulators per warp.
+ typedef typename ShapeMul::Shape AccumulatorsPerWarp;
+ /// The type for A. specialized to half
+ typedef half ScalarA;
+ /// The fragment for A.
+ typedef Fragment FragmentA;
+ /// The type for B. 
specialized to half + typedef half ScalarB; + /// The fragment for B. + typedef Fragment FragmentB; + /// The type for C and D. specialized to float + typedef float ScalarC; + /// The accumulators. + typedef Fragment Accumulators; + + /// Ctor. + CUTLASS_DEVICE ThreadMultiplyAdd() {} + + /// Multiply : d = a*b + c. + CUTLASS_DEVICE void multiply_add(FragmentA const& a, + FragmentB const& b, + Accumulators const& c, + Accumulators& d) { + for (int j = 0; j < AccumulatorsPerThread::kH; ++j) { + for (int i = 0; i < AccumulatorsPerThread::kW; ++i) { + d[j * AccumulatorsPerThread::kW + i] = static_cast(a[i]) * static_cast(b[j]) + c[j * AccumulatorsPerThread::kW + i]; + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace gemm +} // namespace cutlass diff --git a/cutlass/gemm/fp16_sgemm_traits.h b/cutlass/gemm/fp16_sgemm_traits.h new file mode 100644 index 000000000..361186455 --- /dev/null +++ b/cutlass/gemm/fp16_sgemm_traits.h @@ -0,0 +1,152 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. 
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Defines structural properties of single-precision GEMM where any number of the input/output
+ could be fp16 or fp32. The accumulator type stays in fp32
+*/
+#pragma once
+
+#include "cutlass/gemm/gemm.h"
+#include "cutlass/gemm/gemm_epilogue.h"
+#include "cutlass/gemm/gemm_epilogue_traits.h"
+#include "cutlass/gemm/gemm_global_tile.h"
+#include "cutlass/gemm/gemm_shared_tile.h"
+#include "cutlass/gemm/gemm_traits.h"
+#include "cutlass/gemm/fp16_sgemm_multiply_add.h"
+
+namespace cutlass {
+namespace gemm {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+ /// The tile size for the GEMM KxNxM.
+ typename OutputTile_,
+ /// Tile size for thread-level GEMM (K-by-N-by-M)
+ typename ThreadGemmShape_,
+ /// The type for A
+ typename ScalarA_,
+ /// The type for B
+ typename ScalarB_,
+ /// The type for C
+ typename ScalarC_,
+ /// The type for D
+ typename ScalarD_,
+ /// The number of scalars per LDG for A.
+ int kScalarsPerLdgA_ = 1,
+ /// The number of scalars per LDG for B. 
+ int kScalarsPerLdgB_ = 1> +struct Fp16SgemmConfig : public GemmConfig< + /// The scalar type for A. + ScalarA_, + /// The scalar type for B. + ScalarB_, + /// The scalar type for C. + ScalarC_, + /// The scalar type for D. + ScalarD_, + /// The tile size for the GEMM KxNxM. + OutputTile_, + /// The functor to do the math in the main loop. + ThreadMultiplyAdd, ScalarA_, ScalarB_, float /*for sgemm accum is float*/>, + /// The number of scalars per LDG for A. + kScalarsPerLdgA_, + /// The number of scalars per STS for A. + kScalarsPerLdgA_, + /// The number of scalars per LDS for A. + 4, + /// The number of scalars per LDG for B. + kScalarsPerLdgB_, + /// The number of scalars per STS for B. + kScalarsPerLdgB_, + /// The number of scalars per LDS for B. + 4, + /// The number of scalars per LDG for C and STG for D. + 1, + /// The number of scalars per STS for D. + 4, + /// The number of scalars per LDS for D. + 1, + /// The number of stages in shared memory. + 2> {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// The layout for A. + MatrixLayout::Kind kLayoutA_, + /// The layout for B. + MatrixLayout::Kind kLayoutB_, + /// The output tile. + typename OutputTile_ = Shape<8, 128, 128>, + /// The type for A + typename ScalarA_ = half, + /// The type for B + typename ScalarB_ = half, + /// The type for C + typename ScalarC_ = half, + /// The type for D + typename ScalarD_ = half, + /// the Type for alpha and beta, + typename Scalar_ = half, + /// The functor to use in the epilogue. + typename EpilogueFunctor_ = LinearScaling >, + /// Tile size for thread-level GEMM (K-by-N-by-M) + typename ThreadGemmShape_ = Shape<8, 8, 8>, + /// The number of floats loaded in one LDG for A. + int kScalarsPerLdgA_ = 1, + /// The number of floats loaded in one LDG for B. + int kScalarsPerLdgB_ = 1, + /// The index. + typename Index_ = int, + /// The SGEMM config. 
+ typename GemmConfig_ = + Fp16SgemmConfig, + /// The traits class for the epilogue. + typename GemmEpilogueTraits_ = + SimplifiedGemmEpilogueTraits > +struct Fp16SgemmSgemmTraits : public SimplifiedGemmTraits< + // The layout for A. + kLayoutA_, + // The layout for B. + kLayoutB_, + // The config. + GemmConfig_, + // The epilogue. + GemmEpilogue, + // The index. + Index_> {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace gemm +} // namespace cutlass diff --git a/cutlass/gemm/gemm.h b/cutlass/gemm/gemm.h index c50a3f04b..6340ab4f3 100644 --- a/cutlass/gemm/gemm.h +++ b/cutlass/gemm/gemm.h @@ -31,16 +31,17 @@ #include #endif -#include -#include - +#include "cutlass/coord.h" +#include "cutlass/util/platform.h" namespace cutlass { namespace gemm { //////////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel with launch bounds specified template -__global__ /*__launch_bounds__(Gemm_::kThreads)*/ void gemm_kernel(typename Gemm_::Params params) { +__global__ __launch_bounds__(Gemm_::kThreads) +void gemm_kernel(typename Gemm_::Params params) { // Declare shared memory. __shared__ typename Gemm_::SharedStorage shared_storage; @@ -52,28 +53,37 @@ __global__ /*__launch_bounds__(Gemm_::kThreads)*/ void gemm_kernel(typename Gemm //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct GemmDesc { - /// The dimensions of the GEMM. - Index_ m, n, k; - /// The alpha/beta scaling values. - Scalar_ alpha, beta; - /// The source matrix A. - void const* d_a; - /// The stride for A. - Index_ lda; - /// The source matrix B. - void const* d_b; - /// The stride for B. - Index_ ldb; - /// The source matrix C. - void const* d_c; - /// The stride for C. - Index_ ldc; - /// The destination matrix D. - void* d_d; - /// The stride for D. 
- Index_ ldd; +/// GEMM kernel without launch bounds specified +template +__global__ /* __launch_bounds__(Gemm_::kThreads) */ +void gemm_kernel_nolb(typename Gemm_::Params params) { + // Declare shared memory. + __shared__ typename Gemm_::SharedStorage shared_storage; + + // Construct the GEMM object. + Gemm_ gemm(params, shared_storage); + // Run GEMM. + gemm.multiply_add(); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for launching the GEMM kernel with or without launch bounds +template +struct Launch { + Launch(typename Gemm::Params params, dim3 grid, dim3 block, cudaStream_t stream = 0) { + gemm_kernel<<< grid, block, 0, stream >>>(params); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for launching the GEMM kernel with or without launch bounds +template +struct Launch { + Launch(typename Gemm::Params params, dim3 grid, dim3 block, cudaStream_t stream = 0) { + gemm_kernel_nolb<<< grid, block, 0, stream >>>(params); + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -100,86 +110,52 @@ struct Gemm { /// The index. typedef typename Traits::Index Index; + /// Define the mainloop iteration size + typedef typename Traits::MultiplyAdd MultiplyAdd; + /// The number of threads. static int const kThreads = Traits::GemmConfig::kThreads; - /// The params. 
- struct Params : public Traits::Params { - CUTLASS_HOST_DEVICE int initialize(Index m, - Index n, - Index k, - ScalarEpilogue alpha, - ScalarA const* d_a, - Index lda, - ScalarB const* d_b, - Index ldb, - ScalarEpilogue beta, - ScalarC const* d_c, - Index ldc, - ScalarD* d_d, - Index ldd) { - GemmDesc desc; - desc.m = m; - desc.n = n; - desc.k = k; - desc.alpha = alpha; - desc.beta = beta; - desc.d_a = reinterpret_cast(d_a); - desc.lda = lda; - desc.d_b = reinterpret_cast(d_b); - desc.ldb = ldb; - desc.d_c = reinterpret_cast(d_c); - desc.ldc = ldc; - desc.d_d = reinterpret_cast(d_d); - desc.ldd = ldd; - return Traits::Params::initialize(desc); - } - }; + // Number of warp-level multiply-accumulate steps executed by each warp. + static Index const kWarpGemmSteps = + Traits::GemmConfig::AccumulatorsPerWarp::kD / MultiplyAdd::InstructionShape::kD; + // Make sure we have at least 2 unrolling steps or our pipeling is not going to work. + static_assert(kWarpGemmSteps >= 2, "The pipelining assumes at least two steps"); + + /// Use the params object defined in traits + typedef typename Traits::Params Params; + +// +// Static function members +// + +/// Support for NVRTC #if !defined(__CUDACC_RTC__) /// Launch the kernel. static __host__ cudaError_t launch(Params const& params, cudaStream_t stream = cudaStreamDefault) { - // Setup the grid. - dim3 grid; - grid.x = (params.m + Traits::OutputTile::kW - 1) / Traits::OutputTile::kW; - grid.y = (params.n + Traits::OutputTile::kH - 1) / Traits::OutputTile::kH; - - // The number of threads. - dim3 block; - block.x = kThreads; // Launch the kernel. - void const* params_ = reinterpret_cast(¶ms); + Launch( + params, params.grid, params.block, stream); - return cudaLaunchKernel(reinterpret_cast(&gemm_kernel), - grid, - block, - const_cast(¶ms_), - 0, - stream); + return cudaGetLastError(); } /// Launch the kernel. 
static __host__ cudaError_t launch(CUfunction kernel, Params const& params, CUstream stream = CU_STREAM_LEGACY) { - // Setup the grid. - dim3 grid; - grid.x = (params.m + Traits::OutputTile::kW - 1) / Traits::OutputTile::kW; - grid.y = (params.n + Traits::OutputTile::kH - 1) / Traits::OutputTile::kH; - - // The number of threads. - dim3 block; - block.x = kThreads; // Launch the kernel. void* params_[] = {const_cast(reinterpret_cast(&params))}; - // return cudaLaunchKernel(reinterpret_cast(&gemm_kernel), grid, block, - // const_cast(&params_), 0, stream); CUresult result = cuLaunchKernel( - kernel, grid.x, grid.y, grid.z, block.x, block.y, block.z, 0, stream, params_, 0); + kernel, + params.grid.x, params.grid.y, params.grid.z, + params.block.x, params.block.y, params.block.z, + 0, stream, params_, 0); if (result != CUDA_SUCCESS) { return cudaErrorLaunchFailure; @@ -189,39 +165,41 @@ struct Gemm { #endif + // + // Methods + // + /// Ctor. CUTLASS_DEVICE Gemm(Params const& params_, SharedStorage& shared_storage_) : params(params_), shared_storage(shared_storage_) {} - /// Consume a single iteration of the loop. - template - CUTLASS_DEVICE void consume_tile(typename Traits::GlobalLoadStream& global_stream, - typename Traits::SharedLoadStream& shared_load_stream, - typename Traits::MultiplyAdd::Accumulators& accumulators, + /// Computes a warp-level GEMM on data held in shared memory + template + CUTLASS_DEVICE void consume_tile(typename Traits::GlobalLoadStream& global_to_shared_stream, + typename Traits::SharedStream& shared_load_stream, + typename MultiplyAdd::Accumulators& accumulators, Index outer_k) { - // If that's the last "load iteration" update the predicates. - if (!kIsLastIteration) { - global_stream.move_to_residue(outer_k); + // If residue portion and not calculating residue in prolog, update residue predicates now. 
+ if (Residue && outer_k <= Traits::OutputTile::kD) { + global_to_shared_stream.residue(outer_k); } - // Load data for the next iteration of the main loop. - if (!kIsLastIteration) { - global_stream.copy(); + // Load data for the next iteration of the main loop (unless it's the last iteration). + if (!LastIteration) { + global_to_shared_stream.copy(); } - // The unrolling steps for the main loop. - int const kUnrollingSteps = - Traits::MultiplyAdd::AccumulatorsPerWarp::kD / Traits::MultiplyAdd::InstructionShape::kD; - CUTLASS_PRAGMA_UNROLL - for (int step = 0; step < kUnrollingSteps - 1; ++step) { + for (int step = 0; step < kWarpGemmSteps - 1; ++step) { // Trigger the copy from shared memory for the next A/B values. shared_load_stream.copy(step + 1); + // Make sure the values are available for the current iteration to do the multiply-add. shared_load_stream.commit(step); + MultiplyAdd multiply_add; + // Do the math on the fragments of the current iteration. - typename Traits::MultiplyAdd multiply_add; multiply_add.multiply_add(shared_load_stream.fragment_a(step), shared_load_stream.fragment_b(step), accumulators, @@ -232,28 +210,25 @@ struct Gemm { Traits::shared_load_fence(true); // Commit the data in shared memory for A/B. - if (!kIsLastIteration) { - global_stream.commit(); + if (!LastIteration) { + global_to_shared_stream.commit(); } - // Make sure the data is in shared memory. Traits::shared_store_fence(true); - // Trigger the loads for the next iteration (if needed). - if (!kIsLastIteration) { + if (!LastIteration) { // Move to the next stage for the load (if it makes sense). shared_load_stream.inc_stage(); // Trigger the copy from shared memory for the next loop iteration. shared_load_stream.copy(0); } - // Make sure the values are available for the current iteration to do the multiply-add. - shared_load_stream.commit(kUnrollingSteps - 1); + shared_load_stream.commit(kWarpGemmSteps - 1); // Do the math on the fragments of the current iteration. 
- typename Traits::MultiplyAdd multiply_add; - multiply_add.multiply_add(shared_load_stream.fragment_a(kUnrollingSteps - 1), - shared_load_stream.fragment_b(kUnrollingSteps - 1), + MultiplyAdd multiply_add; + multiply_add.multiply_add(shared_load_stream.fragment_a(kWarpGemmSteps - 1), + shared_load_stream.fragment_b(kWarpGemmSteps - 1), accumulators, accumulators); } @@ -262,76 +237,112 @@ struct Gemm { CUTLASS_DEVICE void multiply_add() { // Swizzle the IDs of the block (to enable better cache behavior). typename Traits::BlockSwizzle block_swizzle; - dim3 block = block_swizzle.swizzle(); - - // Scale the id. - block.x *= Traits::OutputTile::kW; - block.y *= Traits::OutputTile::kH; + Coord<3> threadblock_offset = + block_swizzle.get_threadblock_offset(make_Coord_from_shape()); // We may want to use shared memory to clear the registers. typedef typename Traits::ClearAccumulators ClearAccumulators; // The streams to read A/B from global memory to shared memory. - typename Traits::GlobalLoadStream global_stream(params, shared_storage, block); + typename Traits::GlobalLoadStream global_to_shared_stream( + params.global_to_shared_stream, + shared_storage.main_loop.global_to_shared_stream, + shared_storage.main_loop.threadblock_tile.reference(), + params.problem_size.knm(), + threadblock_offset); + + // update A and B pointer offset based on batch_id and batch_stride_offset + //global_to_shared_stream.add_pointer_offset(block_swizzle.get_batch_id(), params.batch_stride_A, params.batch_stride_B); + global_to_shared_stream += make_Coord(block_swizzle.get_batch_id(), 0, 0); // Create the accumulator clear. - ClearAccumulators clear(shared_storage.main_loop.clear); + ClearAccumulators clear; - // By how much we unroll the main loop. - Index const kUnroll = static_cast(Traits::OutputTile::kD); - - // If we do not have enough steps in the main loop, trigger the residue code. - global_stream.move_to_residue(params.k); + // Deal with residue in prolog. 
+ global_to_shared_stream.move_to_residue(params.problem_size[0], Traits::OutputTile::kD); // Fetch the fragments for A and B from global memory. - global_stream.copy(); + global_to_shared_stream.copy(); // Copy the elements to shared memory (after transformation if needed). - global_stream.commit(); + global_to_shared_stream.commit(); // Make sure the data is in shared memory. Traits::shared_store_fence(false); - // Rollback to the beginning of the GEMM-K dimension. It may have no impact. - global_stream.rollback(); - - // The unrolling steps for the main loop. - int const kUnrollingSteps = - Traits::MultiplyAdd::AccumulatorsPerWarp::kD / Traits::MultiplyAdd::InstructionShape::kD; - - // Make sure we have at least 2 unrolling steps or our pipeling is not going to work. - static_assert(kUnrollingSteps >= 2, "The pipelining assumes at least two steps"); + // Rollback to the beginning of the first tile (if residue exists). + global_to_shared_stream.rollback(params.problem_size[0] % Traits::OutputTile::kD); // The stream of data from shared memory to fragments. - typename Traits::SharedLoadStream shared_load_stream(params, shared_storage); + typename Traits::SharedStream shared_load_stream( + params.shared_stream, + shared_storage.main_loop.threadblock_tile.reference()); // Trigger the copy from shared memory for the 1st stream. shared_load_stream.copy(0); // Allocate the accumulators. - typename Traits::MultiplyAdd::Accumulators accumulators; + typename MultiplyAdd::Accumulators accumulators; + // Clear the accumulators. clear.clear(accumulators); - // The loop index. - Index outer_k = params.k - kUnroll; + // Initial index + Index outer_k = params.problem_size[0] - Traits::OutputTile::kD; - // Enter the main loop and iterate. - for (; outer_k > 0; outer_k -= kUnroll) { - consume_tile(global_stream, shared_load_stream, accumulators, outer_k); - } + // Check if we are computing residue in prolog or not. + if (Traits::GemmConfig::kResidueInProlog) { - // Residual loop. 
- for (; outer_k > -kUnroll; outer_k -= kUnroll) { - consume_tile(global_stream, shared_load_stream, accumulators, outer_k); + // Execute all mainloop iterations but the last one. + + CUTLASS_GEMM_LOOP + for (; outer_k > 0; outer_k -= Traits::OutputTile::kD) { + consume_tile( + global_to_shared_stream, shared_load_stream, accumulators, outer_k); + + } + + // Don't load data for the last "residue" portion since we've already computed the residue. + CUTLASS_GEMM_LOOP + for (; outer_k > -Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) { + consume_tile( + global_to_shared_stream, shared_load_stream, accumulators, outer_k); + + } + } else { + // When kResidueSeparate = true, execute all mainloop iterations but the last two without any + // consideration for K-residue or predicate updates. This improves the steady state of some + // kernels. + if (Traits::GemmConfig::kResidueSeparate) { + + CUTLASS_GEMM_LOOP + for (; outer_k > Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) { + consume_tile( + global_to_shared_stream, shared_load_stream, accumulators, outer_k); + + } + } + + // Execute remaining tiles with K-residue predicate updates enabled. + + CUTLASS_GEMM_LOOP + for (; outer_k > -Traits::OutputTile::kD; outer_k -= Traits::OutputTile::kD) { + consume_tile( + global_to_shared_stream, shared_load_stream, accumulators, outer_k); + + } } // Epilogue. typedef typename Traits::Epilogue Epilogue; - Epilogue epilogue(params.epilogue, shared_storage.epilogue, params.m, params.n); - epilogue.epilogue(cutlass::make_Coord(0, block.y, block.x), accumulators); + Epilogue epilogue(params.epilogue, shared_storage.epilogue, params.problem_size.knm()); + epilogue.epilogue(accumulators, threadblock_offset, block_swizzle.get_batch_id()); } + // + // Data members + // + /// The params. Params const& params; /// The shared storage. 
diff --git a/cutlass/gemm/gemm_config.h b/cutlass/gemm/gemm_config.h new file mode 100644 index 000000000..76df0add6 --- /dev/null +++ b/cutlass/gemm/gemm_config.h @@ -0,0 +1,145 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Defines properties of GEMM computation that impose some constraints on caller. +*/ +#pragma once + +#include "cutlass/shape.h" + +namespace cutlass { +namespace gemm { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// The scalar type for A. + typename ScalarA_, + /// The scalar type for B. + typename ScalarB_, + /// The scalar type for C. + typename ScalarC_, + /// The scalar type for D. + typename ScalarD_, + /// The threadblock tile size for the GEMM KxNxM. + typename OutputTile_, + /// The functor to do the math. + typename MultiplyAdd_, + /// The number of scalars per LDG for A. + int kScalarsPerLdgA_, + /// The number of scalars per STS for A. + int kScalarsPerStsA_, + /// The number of scalars per LDG for A. + int kScalarsPerLdsA_, + /// The number of scalars per LDG for B. + int kScalarsPerLdgB_, + /// The number of scalars per STS for B. + int kScalarsPerStsB_, + /// The number of scalars per LDS for B. + int kScalarsPerLdsB_, + /// The number of scalars per LDG for C and STG for D. + int kScalarsPerLdgCAndStgD_, + /// The number of scalars per STS for D. + int kScalarsPerStsD_, + /// The number of scalars per LDS for D. + int kScalarsPerLdsD_, + /// The number of stages in shared memory to do single/double/triple-buffering. + int kStages_, + /// If true, residue is computed in mainloop. If false, separate loops are instantiated. + bool kResidueSeparate_ = false, + /// Is residue performed in prologue? + bool kResidueInProlog_ = false, + /// If true, kernel is launched with CUDA launch bounds specified + bool kLaunchBounds_ = true> +struct GemmConfig { + // + /// The scalar for A. + typedef ScalarA_ ScalarA; + /// The scalar for B. + typedef ScalarB_ ScalarB; + /// The scalar for C. + typedef ScalarC_ ScalarC; + /// The scalar for D. + typedef ScalarD_ ScalarD; + + /// The tile. + typedef OutputTile_ OutputTile; + /// The functor to do D = A*B + C. 
+ typedef MultiplyAdd_ MultiplyAdd; + /// The shape of the instruction. + typedef typename MultiplyAdd::InstructionShape InstructionShape; + /// The shape of warp-level GEMM + typedef typename MultiplyAdd::AccumulatorsPerWarp AccumulatorsPerWarp; + /// The accumulators. + typedef typename MultiplyAdd::Accumulators Accumulators; + + /// The number of warps. + typedef typename ShapeDiv::Shape Warps; + /// The default warp size (32 threads per warp). + static int const kWarpSize = cutlass::kWarpSize; + /// The number of threads. + static int const kThreads = ShapeCount::kCount * kWarpSize; + + /// The number of scalars per LDG/STS/LDS for A. + static int const kScalarsPerLdgA = kScalarsPerLdgA_; + static int const kScalarsPerStsA = kScalarsPerStsA_; + static int const kScalarsPerLdsA = kScalarsPerLdsA_; + + /// The number of scalars per LDG/STS/LDS for B. + static int const kScalarsPerLdgB = kScalarsPerLdgB_; + static int const kScalarsPerStsB = kScalarsPerStsB_; + static int const kScalarsPerLdsB = kScalarsPerLdsB_; + + /// The number of scalars per LDG for C. + static int const kScalarsPerLdgC = kScalarsPerLdgCAndStgD_; + + /// The number of scalars per STS/LDS/STG for D. + static int const kScalarsPerStgD = kScalarsPerLdgCAndStgD_; + static int const kScalarsPerStsD = kScalarsPerStsD_; + static int const kScalarsPerLdsD = kScalarsPerLdsD_; + + /// The number of accumulators that are going to be fed from one LDS A/B. + static int const kAccumulatorsPerLdsA = kScalarsPerLdsA / InstructionShape::kD; + static int const kAccumulatorsPerLdsB = kScalarsPerLdsB / InstructionShape::kD; + + /// The number of stages in shared memory to implement double, triple, more-buffering. + static int const kStages = kStages_; + + /// If true, mainloop is instantiated twice. The first instantiation contains no predicate + // updates and is more efficient for some kernels. If false, only a single mainloop is + // instantiated. 
+ static bool const kResidueSeparate = kResidueSeparate_; + + /// If true, residue is computed in the prologue. + static bool const kResidueInProlog = kResidueInProlog_; + + /// If true, kernel is launched with launch bounds specified + static bool const kLaunchBounds = kLaunchBounds_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace gemm +} // namespace cutlass diff --git a/cutlass/gemm/gemm_coord.h b/cutlass/gemm/gemm_coord.h new file mode 100644 index 000000000..8e36bb043 --- /dev/null +++ b/cutlass/gemm/gemm_coord.h @@ -0,0 +1,203 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief GemmCoord is a structure derived from Coord<4> that specifies a location within the + coordinate system of a GEMM problem. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/coord.h" +#include "cutlass/util/platform.h" + +namespace cutlass { +namespace gemm { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// GemmCoord is a structure derived from Coord<4> that specifies a location within the +/// coordinate space of a GEMM problem. 
+struct GemmCoord : public Coord<4, int> { + + /// Integer-valued index + typedef int Index; + + /// Base type is a Coord of rank=4 + typedef Coord<4, Index> Base; + + /// GEMM K dimension - inner dimension of the GEMM problem + static int const kK = 0; + + /// GEMM N dimension - columns of the output C matrix + static int const kN = 1; + + /// GEMM M dimension - rows of the output C matrix + static int const kM = 2; + + /// Batch dimension - for generalizing to larger problems + static int const kBatch = 3; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + GemmCoord() { } + + /// Constructs from Coord<3> and a batch + CUTLASS_HOST_DEVICE + GemmCoord(Coord<3, Index> const &coord, Index _batch = 0): Base(make_Coord(coord[0], coord[1], coord[2], _batch)) { } + + /// Constructs from Coord<4> + CUTLASS_HOST_DEVICE + GemmCoord(Coord<4, Index> const &coord): Base(coord) { } + + /// Constructs from an array of coordinate elements + CUTLASS_HOST_DEVICE + GemmCoord(Index coord[4]): Base(coord) { } + + /// Helper to construct from a K, N, M, batch variables + CUTLASS_HOST_DEVICE + GemmCoord(Index k, Index n, Index m, Index batch = 0): Base(make_Coord(k, n, m, batch)) { } + + /// Returns the GEMM M coordinate + CUTLASS_HOST_DEVICE + Index const & m() const { return this->at(kM); } + + /// Returns reference to the GEMM M coordinate + CUTLASS_HOST_DEVICE + Index & m() { return this->at(kM); } + + /// Returns the GEMM N coordinate + CUTLASS_HOST_DEVICE + Index const & n() const { return this->at(kN); } + + /// Returns reference to the GEMM N coordinate + CUTLASS_HOST_DEVICE + Index & n() { return this->at(kN); } + + /// Returns the GEMM K coordinate + CUTLASS_HOST_DEVICE + Index const & k() const { return this->at(kK); } + + /// Returns reference to the GEMM K coordinate + CUTLASS_HOST_DEVICE + Index & k() { return this->at(kK); } + + /// Returns the GEMM batch coordinate + CUTLASS_HOST_DEVICE + Index const & batch() const { return this->at(kBatch); } + + /// 
Returns reference to the GEMM batch coordinate + CUTLASS_HOST_DEVICE + Index & batch() { return this->at(kBatch); } + + /// Obtains a Coord<3> from GemmCoord + CUTLASS_HOST_DEVICE + Coord<3> knm() const { + return make_Coord(k(), n(), m()); + } + + /// Obtains a Coord<2> from GemmCoord + CUTLASS_HOST_DEVICE + Coord<2> nm() const { + return make_Coord(n(), m()); + } + + /// Obtains a Coord<2> from GemmCoord + CUTLASS_HOST_DEVICE + Coord<2> km() const { + return make_Coord(k(), m()); + } + + /// Obtains a Coord<2> from GemmCoord + CUTLASS_HOST_DEVICE + Coord<2> kn() const { + return make_Coord(k(), n()); + } + + // + // Coord operators + // + + /// Element-wise addition + CUTLASS_HOST_DEVICE + GemmCoord operator+(Base const& b) const { + return GemmCoord(Base::operator+(b)); + } + + /// Element-wise subtraction + CUTLASS_HOST_DEVICE + GemmCoord operator-(Base const& b) const { + return GemmCoord(Base::operator-(b)); + } + + /// Element-wise multiplication + CUTLASS_HOST_DEVICE + GemmCoord operator*(Base const& b) const { + return GemmCoord(Base::operator*(b)); + } + + /// Element-wise division + CUTLASS_HOST_DEVICE + GemmCoord operator/(Base const& b) const { + return GemmCoord(Base::operator/(b)); + } + + /// In-place addition + CUTLASS_HOST_DEVICE + GemmCoord& operator+=(Base const& b) { + Base::operator+=(b); + return *this; + } + + /// In-place subtraction + CUTLASS_HOST_DEVICE + GemmCoord& operator-=(Base const& b) { + Base::operator-=(b); + return *this; + } + + /// In-place multiplication + CUTLASS_HOST_DEVICE + GemmCoord& operator*=(Base const& b) { + Base::operator*=(b); + return *this; + } + + /// In-place division + CUTLASS_HOST_DEVICE + GemmCoord& operator/=(Base const& b) { + Base::operator/=(b); + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace gemm +} // namespace cutlass diff --git a/cutlass/gemm/gemm_desc.h b/cutlass/gemm/gemm_desc.h new file mode 100644 
index 000000000..80f4b3655 --- /dev/null +++ b/cutlass/gemm/gemm_desc.h @@ -0,0 +1,205 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implements a software-pipelined efficient GEMM. 
+*/ +#pragma once + +#include "cutlass/tensor_ref.h" +#include "cutlass/gemm/gemm_coord.h" + +namespace cutlass { +namespace gemm { + +/// GEMM problem description +template < + /// Source accumulator matrix type + typename AType_, + /// Destination accumulator type + typename BType_, + /// Source accumulator matrix type + typename CType_, + /// Destination accumulator type + typename DType_, + /// Scalar type for alpha and beta + typename SType_, + /// Index type for dimensions and strides + typename Index_ = int +> struct GemmDesc { + // + // Type definitions + // + + /// Index type for dimensions and strides + typedef Index_ Index; + + /// Source accumulator matrix type + typedef AType_ AType; + + /// Tensor reference to A operand + typedef TensorRef TensorRefA; + + /// Destination accumulator type + typedef BType_ BType; + + /// Tensor reference to B operand + typedef TensorRef TensorRefB; + + /// Source accumulator matrix type + typedef CType_ CType; + + /// Tensor reference to C operand + typedef TensorRef TensorRefC; + + /// Destination accumulator type + typedef DType_ DType; + + /// Tensor reference to D operand + typedef TensorRef TensorRefD; + + /// Scalar type for alpha and beta + typedef SType_ SType; + + // + // Data members + // + + /// The dimensions of the GEMM. + GemmCoord problem_size; + + /// The alpha scaling values. + SType alpha; + + /// The source matrix A. + TensorRefA A; + + /// batch stride for A operand + long long batch_stride_A; + + /// The source matrix B. + TensorRefB B; + + /// batch stride for B operand + long long batch_stride_B; + + /// The beta scaling values. + SType beta; + + /// The source matrix C. + TensorRefC C; + + /// batch stride for C operand + long long batch_stride_C; + + /// The destination matrix D. 
+ TensorRefD D; + + /// batch stride for D operand + long long batch_stride_D; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + GemmDesc(): problem_size(0, 0, 0, 1), alpha(1), beta(0) {} + + /// Constructor for basic GEMM with batch count = 1 + CUTLASS_HOST_DEVICE + GemmDesc(Coord<3> _problem_size, + SType _alpha, + TensorRefA const &_A, + TensorRefB const &_B, + SType _beta, + TensorRefC const &_C, + TensorRefD const &_D + ): + problem_size(_problem_size[0], _problem_size[1], _problem_size[2], 1), + alpha(_alpha), + A(_A), + batch_stride_A(0), + B(_B), + batch_stride_B(0), + beta(_beta), + C(_C), + batch_stride_C(0), + D(_D), + batch_stride_D(0) {} + + /// Constructor for basic GEMM with batch count = 1 + CUTLASS_HOST_DEVICE + GemmDesc(GemmCoord _problem_size, + SType _alpha, + TensorRefA const &_A, + TensorRefB const &_B, + SType _beta, + TensorRefC const &_C, + TensorRefD const &_D + ): + problem_size(_problem_size.k(), _problem_size.n(), _problem_size.m(), 1), + alpha(_alpha), + A(_A), + batch_stride_A(0), + B(_B), + batch_stride_B(0), + beta(_beta), + C(_C), + batch_stride_C(0), + D(_D), + batch_stride_D(0) { + + assert(_problem_size.batch() == 1); + } + + /// Constructor for strided batch GEMM + CUTLASS_HOST_DEVICE + GemmDesc(GemmCoord _problem_size, + SType _alpha, + TensorRefA const &_A, + long long _batch_stride_A, + TensorRefB const &_B, + long long _batch_stride_B, + SType _beta, + TensorRefC const &_C, + long long _batch_stride_C, + TensorRefD const &_D, + long long _batch_stride_D + ): + problem_size(_problem_size), + alpha(_alpha), + A(_A), + batch_stride_A(_batch_stride_A), + B(_B), + batch_stride_B(_batch_stride_B), + beta(_beta), + C(_C), + batch_stride_C(_batch_stride_C), + D(_D), + batch_stride_D(_batch_stride_D) {} +}; + +} // namespace gemm +} // namespace cutlass diff --git a/cutlass/gemm/gemm_epilogue.h b/cutlass/gemm/gemm_epilogue.h index bc2530777..d9469bb55 100644 --- a/cutlass/gemm/gemm_epilogue.h +++ 
b/cutlass/gemm/gemm_epilogue.h @@ -29,26 +29,15 @@ */ #pragma once -#include -#include -#include +#include "cutlass/convert.h" +#include "cutlass/coord.h" +#include "cutlass/fragment.h" namespace cutlass { namespace gemm { //////////////////////////////////////////////////////////////////////////////////////////////////// -template -CUTLASS_DEVICE bool is_zero(T x) { - return x == T(0); -} - -#if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16) -CUTLASS_DEVICE bool is_zero(half x) { return reinterpret_cast(x) == int16_t(0); } -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - template struct GemmEpilogue { /// The traits class. @@ -85,9 +74,7 @@ struct GemmEpilogue { /// The shared store transformer for D. typedef typename Traits::SharedStoreTransformerD SharedStoreTransformerD; /// The iterator to load D in shared memory. - typedef typename Traits::SharedLoadIteratorD SharedLoadIteratorD; - /// The shared load transformer for D. - typedef Copy SharedLoadTransformerD; + typedef typename Traits::SharedLoadStreamD SharedLoadStreamD; /// The index. typedef typename Traits::Index Index; @@ -100,33 +87,28 @@ struct GemmEpilogue { /// Ctor. CUTLASS_DEVICE GemmEpilogue(Params const& params_, SharedStorage& shared_storage_, - Index m_, - Index n_) - : params(params_), shared_storage(shared_storage_), m(m_), n(n_) {} + Coord<3> const& _problem_size) + : params(params_), shared_storage(shared_storage_), problem_size(_problem_size), functor(params_.functor) {} /// Execute the epilogue. 
- CUTLASS_DEVICE void epilogue(Coord<3> const& block, Accumulators& accumulators) { - if (is_zero(params.functor.beta)) { - epilogue_with_or_without_beta(block, accumulators); + CUTLASS_DEVICE void epilogue(Accumulators& accumulators, + Coord<3> const& block = make_Coord(0, 0, 0), + int batch_id = 0) { + if (functor.source_required()) { + epilogue_with_or_without_beta(accumulators, block, batch_id); } else { - epilogue_with_or_without_beta(block, accumulators); + epilogue_with_or_without_beta(accumulators, block, batch_id); } } - template - CUTLASS_DEVICE void epilogue_with_or_without_beta(Coord<3> const& block, - Accumulators& accumulators) { - - // The problem size. - Coord<3> const bounds = cutlass::make_Coord(0, n, m); - - // The functor. - Functor functor(params.functor); + template + CUTLASS_DEVICE void epilogue_with_or_without_beta(Accumulators& accumulators, + Coord<3> const& block, + int batch_id) { // The C fragment. typename GlobalLoadIteratorC::Fragment fragment_c; // The transformed C fragment. typename GlobalTransformerC::OutputFragment transformed_c; - CUTLASS_PRAGMA_UNROLL for (int h = 0; h < Iterations::kH; ++h) { // Compute pointer and predicate offsets for C and D global iterators. @@ -136,6 +118,7 @@ struct GemmEpilogue { Iterations::kW + params.stride_h) * h; + int const predicate_offset = ((params.iterator_d.predicate_inc_h * (GlobalStoreIteratorD::Iterations::kH - 1) + params.iterator_d.predicate_inc_advance) * @@ -145,32 +128,40 @@ struct GemmEpilogue { // The iterator to load the elements of the C matrix. GlobalLoadIteratorC global_load_iterator( - params.iterator_c, bounds, block, pointer_offset, predicate_offset); + params.iterator_c, problem_size, block, pointer_offset, predicate_offset); + + // update C pointer offset based on batch_id and batch_stride_offset + //global_load_iterator.add_pointer_offset(batch_id * params.batch_stride_offset_c); + global_load_iterator += make_Coord(batch_id, 0, 0); + // The transformer for C. 
GlobalTransformerC transformer_c; // The transformer for D. GlobalTransformerD transformer_d; // The iterator to store into the D matrix. GlobalStoreIteratorD global_store_iterator( - params.iterator_d, bounds, block, pointer_offset, predicate_offset); + params.iterator_d, problem_size, block, pointer_offset, predicate_offset); + + // update D pointer offset based on batch_id and batch_stride_offset + //global_store_iterator.add_pointer_offset(batch_id * params.batch_stride_offset_d); + global_store_iterator += make_Coord(batch_id, 0, 0); - // The transformer to transform before storing to shared memory. SharedStoreTransformerD shared_store_transformer; typename SharedStoreTransformerD::OutputFragment shared_store_transformed_d; - // The iterator to store to shared memory. - SharedStoreIteratorD shared_store_iterator(params.shared_store_iterator_d, - shared_storage.shared_stream.store); + SharedStoreIteratorD shared_store_iterator( + params.shared_store_iterator_d, + reinterpret_cast(shared_storage.data())); - // The iterator to load from shared memory. TODO: Use a stream. - SharedLoadIteratorD shared_load_iterator(params.shared_load_iterator_d, - shared_storage.shared_stream.load); + SharedLoadStreamD shared_load_stream( + params.shared_load_stream_d, + reinterpret_cast(shared_storage.data())); CUTLASS_PRAGMA_UNROLL for (int w = 0; w < Iterations::kW; ++w) { // Load the C matrix into fragment. - if (!kBetaIsZero_) { - iterator_load(global_load_iterator, fragment_c); + if (kSourceRequired) { + global_load_iterator.load_post_increment(fragment_c); } // Make sure we can write to shared memory. 
@@ -180,33 +171,33 @@ struct GemmEpilogue { int const offset = (h * Iterations::kW + w) * SharedStoreIteratorD::Fragment::kElements; shared_store_transformer.transform(accumulators, offset, shared_store_transformed_d); - shared_iterator_store(shared_store_iterator, shared_store_transformed_d); + shared_store_iterator.store_post_increment(shared_store_transformed_d); // Make sure the data is in shared memory. shared_store_fence(); // Copy the accumulators back to registers from shared memory. - typename SharedLoadIteratorD::Fragment fetched_d; - shared_iterator_load(shared_load_iterator, fetched_d); + shared_load_stream.copy(); + shared_load_stream.commit(); // Do the math. typename GlobalTransformerD::InputFragment fragment_d; - if (kBetaIsZero_) { - functor.evaluate(fetched_d, fragment_d); - } else { + if (kSourceRequired) { // Transform C fragment. transformer_c.transform(fragment_c, transformed_c); // Do the math. - functor.evaluate(fetched_d, transformed_c, fragment_d); + functor.evaluate(shared_load_stream.fragment(), transformed_c, fragment_d); + } else { + functor.evaluate(shared_load_stream.fragment(), fragment_d); } // Transform D fragment. - typename GlobalTransformerD::OutputFragment transformed_d; - transformer_d.transform(fragment_d, transformed_d); + typename GlobalTransformerD::OutputFragment global_transformed_d; + transformer_d.transform(fragment_d, global_transformed_d); // Copy the results to global memory. - iterator_store(global_store_iterator, transformed_d); + global_store_iterator.store_post_increment(global_transformed_d); } } } @@ -222,7 +213,9 @@ struct GemmEpilogue { /// The shared storage. SharedStorage& shared_storage; /// The dimensions of the GEMM. - Index m, n; + Coord<3> problem_size; + // The functor. 
+ Functor functor; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cutlass/gemm/gemm_epilogue_traits.h b/cutlass/gemm/gemm_epilogue_traits.h index c06fc2502..c6aff71e1 100644 --- a/cutlass/gemm/gemm_epilogue_traits.h +++ b/cutlass/gemm/gemm_epilogue_traits.h @@ -27,13 +27,13 @@ */ #pragma once -#include -#include -#include -#include -#include -#include -#include +#include "cutlass/convert.h" +#include "cutlass/coord.h" +#include "cutlass/gemm/gemm_global_stream.h" +#include "cutlass/gemm/gemm_shared_stream.h" +#include "cutlass/gemm/linear_scaling.h" +#include "cutlass/reshape_tile.h" +#include "cutlass/tile_iterator.h" namespace cutlass { namespace gemm { @@ -57,8 +57,8 @@ template < typename SharedStoreIteratorD_, /// The shared store transformer for D. typename SharedStoreTransformerD_, - /// The iterator to load D from shared memory. - typename SharedLoadIteratorD_, + /// The stream to load D from shared memory. + typename SharedLoadStreamD_, /// The number of iterations in the epilogue. typename Iterations_, /// The iterations strides. @@ -86,8 +86,8 @@ struct GemmEpilogueTraits { typedef SharedStoreIteratorD_ SharedStoreIteratorD; /// The shared store transformer for D. typedef SharedStoreTransformerD_ SharedStoreTransformerD; - /// The iterator to store D in shared memory. - typedef SharedLoadIteratorD_ SharedLoadIteratorD; + /// The stream to store D in shared memory. + typedef SharedLoadStreamD_ SharedLoadStreamD; /// typedef typename GemmConfig::EpilogueIterations Iterations; typedef Iterations_ Iterations; /// The iterations strides. @@ -118,14 +118,15 @@ struct GemmEpilogueTraits { typename GlobalStoreIteratorD::Params iterator_d; /// The params for the D shared store iterator. typename SharedStoreIteratorD::Params shared_store_iterator_d; - /// The params for the D shared load iterator. 
- typename SharedLoadIteratorD::Params shared_load_iterator_d; + /// The params for the D shared load stream. + typename SharedLoadStreamD::Params shared_load_stream_d; /// The functor params. typename Functor::Params functor; /// Setup the params. template CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) { + // The parameters for the functor. int error_code = functor.initialize(desc); if (error_code) { @@ -133,20 +134,27 @@ struct GemmEpilogueTraits { } // At the end of the H iteration, we jump over a number of columns. - this->stride_h = desc.ldd * Delta::kH; + this->stride_h = desc.D.leading_dim() * Delta::kH; // Nothing to do here. this->stride_w = 0; - // Setup the params for the global memory iterator for C. - error_code = iterator_c.initialize( - reinterpret_cast(desc.d_c), desc.ldc, desc.n, stride_w, Delta::kW); + error_code = iterator_c.initialize(desc.C.data(), + desc.batch_stride_C, + desc.C.leading_dim(), + desc.problem_size[1], + stride_w, + Delta::kW); if (error_code) { return error_code; } // Setup the params for the global memory iterator for D. - return iterator_d.initialize( - reinterpret_cast(desc.d_d), desc.ldd, desc.n, stride_w, Delta::kW); + return iterator_d.initialize(desc.D.data(), + desc.batch_stride_D, + desc.D.leading_dim(), + desc.problem_size[1], + stride_w, + Delta::kW); } }; @@ -155,13 +163,20 @@ struct GemmEpilogueTraits { // The storage for the store iterator. typename SharedStoreIteratorD::SharedStorage store; // The storage for the store iterator. - typename SharedLoadIteratorD::SharedStorage load; + typename SharedLoadStreamD::SharedStorage load; }; /// The shared memory to swizzle the data in the epilogue. struct SharedStorage { // The storage for the shared stream D. 
StreamSharedStorage shared_stream; + + // + // + // + + CUTLASS_DEVICE + ScalarD* data() { return reinterpret_cast(&shared_stream.load); } }; }; @@ -192,7 +207,10 @@ struct GemmEpilogueTraitsHelper { /// The traits class to build the iterator to store to shared memory for D. typedef GemmSharedStoreTileDTraits< // The pointer is float. - typename Functor::Scalar, + // typename Functor::Scalar, + // Functor::Scalar is alpha, beta type, in mixed precision, alpha and beta may not be the same with accumulation. + // In this case Functor::ScalarAccum is needed + typename Functor::ScalarAccum, // The output tile size. typename GemmConfig_::OutputTile, // The number of warps. @@ -221,7 +239,10 @@ struct GemmEpilogueTraitsHelper { /// The traits class to build the iterator to load from shared memory for D. typedef GemmSharedLoadTileDTraits< // The pointer is float. - typename Functor::Scalar, + // typename Functor::Scalar, + // Functor::Scalar is alpha, beta type, in mixed precision, alpha and beta may not be the same with accumulation. + // In this case Functor::ScalarAccum is needed + typename Functor::ScalarAccum, // The output tile size. typename GemmConfig_::OutputTile, // The number of warps. @@ -242,6 +263,8 @@ struct GemmEpilogueTraitsHelper { IteratorAdvance::kH, MemorySpace::kShared> SharedLoadIteratorD; + /// The stream to load D. + typedef SharedLoadStream SharedLoadStreamD; /// The traits class to build the iterator to load data from global memory for C^N. typedef GemmGlobalTileCdTraits< @@ -314,8 +337,8 @@ struct SimplifiedGemmEpilogueTraits : public GemmEpilogueTraits< typename Helper_::SharedStoreIteratorD, // The shared store transformer for D. typename Helper_::SharedStoreTransformerD, - // The iterator to load D from shared memory. - typename Helper_::SharedLoadIteratorD, + // The stream to load D from shared memory. + typename Helper_::SharedLoadStreamD, // The number of iterations. typename Helper_::Iterations, // The strides between iterations. 
diff --git a/cutlass/gemm/gemm_global_stream.h b/cutlass/gemm/gemm_global_stream.h index ec675a38f..6ea72cf30 100644 --- a/cutlass/gemm/gemm_global_stream.h +++ b/cutlass/gemm/gemm_global_stream.h @@ -29,9 +29,10 @@ */ #pragma once -#include -#include -#include +#include "cutlass/coord.h" +#include "cutlass/convert.h" +#include "cutlass/gemm/gemm_global_tile.h" +#include "cutlass/tile_allocation.h" namespace cutlass { namespace gemm { @@ -39,6 +40,8 @@ namespace gemm { //////////////////////////////////////////////////////////////////////////////////////////////////// template < + /// Identifies multiplicand + GemmOperand::Kind Operand, /// The load iterator. typename LoadIterator_, /// The store iterator to copy to shared memory. @@ -46,7 +49,9 @@ template < /// The transformer to be applied after the data has been copied from global memory. typename Transformer_> -struct GlobalLoadStreamBase { +struct GlobalLoadStream { + /// Indicates the type of GEMM operand + static GemmOperand::Kind const kOperand = Operand; /// The load iterator. typedef LoadIterator_ LoadIterator; /// The transformer. @@ -75,6 +80,15 @@ struct GlobalLoadStreamBase { typedef typename LoadIterator::Pointer Pointer; /// The index. typedef typename LoadIterator::Index Index; + /// The tile + typedef typename LoadIterator::Tile Tile; + + /// Shared memory allocation for the tile + typedef TileAllocation + ThreadblockTileStorage; + + /// Tensor reference to threadblock tile + typedef typename ThreadblockTileStorage::TensorRef ThreadblockTileRef; /// The params. struct Params { @@ -82,56 +96,73 @@ struct GlobalLoadStreamBase { typename LoadIterator::Params load_iterator; // The store iterator. typename StoreIterator::Params store_iterator; + // Offset to residue. + Index offset_to_residue; /// Setup the params. 
- template - CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc, Pointer pointer, Index ld) { - int error_code = load_iterator.initialize(desc, pointer, ld); + CUTLASS_HOST_DEVICE int initialize(Pointer pointer, + long long batch_stride, + Index ldm, + Index _offset_to_residue) { + + offset_to_residue = _offset_to_residue; + int error_code = load_iterator.initialize(pointer, batch_stride, ldm); if (error_code) { return error_code; } - return store_iterator.initialize(); } }; - /// The amount of storage in shared memory needed to store the tile. - typedef typename StoreIterator::SharedStorage SharedStoreStorage; + /// Contains private storage in shared memory needed by the objects within this class. Note, + /// this is *NOT* the shared memory allocation for the GEMM threadblock tile. That necessarily + /// exists outside this class, as it is also needed by the warp-level shared=>RF stream. + struct SharedStorage {}; - /// The storage in shared memory needed by that stream. - union SharedStorage { - // The load iterator. - typename LoadIterator::SharedStorage load_iterator; - // The store iterator. - SharedStoreStorage store_iterator; - }; + // + // Static member functions + // + + /// Maps a coordinate in the GEMM's (K, N, M) coordinate system to global memory + CUTLASS_DEVICE static Coord<3> project_coordinate(Coord<3> const& coord, Index d_offset = 0) { + bool const kKstrided = + GemmMultiplicandTraits::kKstrided; + Coord<3> tile_coord = ProjectOperand::project(coord); + return make_Coord( + tile_coord[0] + d_offset, tile_coord[1], tile_coord[2] / LoadIterator::Tile::kC); + } /// Ctor. 
- CUTLASS_DEVICE GlobalLoadStreamBase(Params const& params,
-                                     SharedStorage& shared_storage,
-                                     Coord<3> const bounds,
-                                     Coord<3> const& block)
-     : load_iterator(params.load_iterator, bounds, block),
+ CUTLASS_DEVICE GlobalLoadStream(
+     Params const& _params,
+     SharedStorage& shared_storage,
+     ThreadblockTileRef const& threadblock_tile_ref,
+     Coord<3> const bounds,
+     Coord<3> const& _threadblock_offset)
+     : params(_params),
+       multiplicand_bounds(project_coordinate(bounds, 1)),
+       threadblock_offset(project_coordinate(_threadblock_offset)),
+       load_iterator(params.load_iterator,
+                     project_coordinate(bounds, 1), /*multiplicand_bounds*/
+                     project_coordinate(_threadblock_offset) /*threadblock_offset*/),
        transformer(),
-       store_iterator(params.store_iterator, shared_storage.store_iterator)
-
+       store_iterator(params.store_iterator, threadblock_tile_ref.data()) {
+    load_iterator.initialize_predicates(multiplicand_bounds, threadblock_offset);
     fetched_fragment.clear();
   }
 
+  /// Load the data from shared memory to the fetch fragment.
-  CUTLASS_DEVICE void copy() { iterator_load(load_iterator, fetched_fragment); }
+  CUTLASS_DEVICE void copy() { load_iterator.load_post_increment(fetched_fragment); }
 
   /// Commit the data.
   CUTLASS_DEVICE void commit() {
     transformer.transform(fetched_fragment, transformed_fragment);
-    iterator_store(store_iterator, transformed_fragment);
+    store_iterator.store_post_increment(transformed_fragment);
     store_iterator.inc_stage();
   }
 
-  /// Move to the beginning of the residue code. That's a new code path in CUTLASS 1.0.1.
-  CUTLASS_DEVICE void move_to_residue(Index k) { load_iterator.move_to_residue(k); }
-
   /// Execute the residue code.
   CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) {
     load_iterator.residue(k);
@@ -140,9 +171,43 @@ struct GlobalLoadStreamBase {
     }
   }
 
-  /// Rollback to the beginning of the GEMM-k dimension.
-  CUTLASS_DEVICE void rollback() { load_iterator.rollback(); }
+  /// Move to the residue portion.
+ CUTLASS_DEVICE void move_to_residue(Index k, Index kTileK) { + Index kResidue = k % kTileK; + if (kResidue) { + residue(kResidue); + } + load_iterator.add_pointer_offset(params.offset_to_residue * load_iterator.stride_advance()); + } + /// Rollback to the beginning of the first tile + CUTLASS_DEVICE void rollback(void) { + load_iterator.initialize_predicates(multiplicand_bounds, threadblock_offset); + + int const kBlock = kOperand == GemmOperand::kA + ? (kLayout == MatrixLayout::kColumnMajor ? Tile::kH : Tile::kW) + : (kLayout == MatrixLayout::kRowMajor ? Tile::kH : Tile::kW); + + load_iterator.add_pointer_offset(-(params.offset_to_residue + kBlock) * + load_iterator.stride_advance()); + } + + /// Adds a Coord<3> to the underlying global load iterator + CUTLASS_DEVICE GlobalLoadStream &operator+=(Coord<3> const &offset) { + load_iterator += offset; + return *this; + } + + // + // Data members + // + + /// Parameters + Params params; + /// Multiplicand bounds + Coord<3> multiplicand_bounds; + /// Threadblock offset + Coord<3> threadblock_offset; /// The iterator. LoadIterator load_iterator; /// The fragment to fetch from shared memory. @@ -155,28 +220,6 @@ struct GlobalLoadStreamBase { StoreIterator store_iterator; }; -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template < - /// The load iterator. - typename LoadIterator_, - /// The store iterator to copy to shared memory. - typename StoreIterator_, - /// The transformer to be applied after the data has been copied from global memory. - typename Transformer_ = Copy > - -struct GlobalLoadStream : public GlobalLoadStreamBase { - /// The base class. - typedef GlobalLoadStreamBase Base; - - /// Ctor. 
- CUTLASS_DEVICE GlobalLoadStream(typename Base::Params const& params, - typename Base::SharedStorage& shared_storage, - Coord<3> const& bounds, - Coord<3> const& block) - : Base(params, shared_storage, bounds, block) {} -}; - //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace gemm } // namespace cutlass diff --git a/cutlass/gemm/gemm_global_tile.h b/cutlass/gemm/gemm_global_tile.h index 1cc3b3377..a355ebea0 100644 --- a/cutlass/gemm/gemm_global_tile.h +++ b/cutlass/gemm/gemm_global_tile.h @@ -27,14 +27,14 @@ */ #pragma once -#include -#include +#include "cutlass/coord.h" +#include "cutlass/util/platform.h" -#include -#include -#include -#include -#include +#include "cutlass/gemm/gemm_operand.h" +#include "cutlass/matrix_traits.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/reshape_tile.h" +#include "cutlass/tile_iterator.h" namespace cutlass { namespace gemm { @@ -80,20 +80,24 @@ struct GemmGlobalTileTraits { static int const kAccessSize = kAccessSize_; /// The memory space. static MemorySpace::Kind const kMemorySpace = MemorySpace::kGlobal; - /// The tile shape - typedef typename ReshapeTile::Tile Tile; + typedef Tile_ Tile; + /// The vectorized tile shape + typedef typename ReshapeTile::Tile VectorizedTile; /// The threads shape - typedef typename ReshapeThreads::Threads Threads; + typedef typename ReshapeThreads::Threads Threads; /// The relative offset between two elements in the H/W dimension in adjacent threads. - typedef Shape<1, 1, Tile::kC> ThreadsDelta; - + typedef Shape<1, 1, VectorizedTile::kC> ThreadsDelta; /// The strides in each dimension between different loads/stores. typedef Shape<0, Threads::kH, Threads::kW * kAccessSize> Delta; + /// Strides for immediate offset computation typedef Shape<0, 0, Threads::kW * ThreadsDelta::kW, kAccessSize> ImmediateOffsetStrides; /// The number of iterations needed to load/store the tile. 
- typedef Shape<1, Tile::kH / Threads::kH, Tile::kW / Threads::kW, Tile::kC / kAccessSize> + typedef Shape<1, + VectorizedTile::kH / Threads::kH, + VectorizedTile::kW / Threads::kW, + VectorizedTile::kC / kAccessSize> Iterations; typedef GemmMultiplicandTraits MultiplicandTraits; @@ -165,7 +169,6 @@ struct GemmGlobalIteratorAb Index_> { /// This class. typedef GemmGlobalIteratorAb This_; /// The base class. - typedef TileLoadIterator - CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc, Scalar const* ptr, Index stride_h) { + CUTLASS_HOST_DEVICE int initialize(Scalar const* ptr, + long long stride_d, + Index stride_h) { Index inc_d = 0; Index inc_advance = 0; // Move by some columns for each iteration in the H dimension. @@ -221,99 +227,36 @@ struct GemmGlobalIteratorAb (Base::Iterations::kH - 1) * inc_h; } - // The dimensions of the tile. - int const kH = TileTraits_::Tile::kH; - int const kW = TileTraits_::Tile::kW * TileTraits_::kAccessSize; - - // Move to the residue. - Index const kBlock = kAdvance == IteratorAdvance::kH ? kH : kW; - // The jump in the gemm-k dimension. - Index const stride = kAdvance == IteratorAdvance::kH ? stride_h : 1; - - // Compute the offset to the residue and how to "come" back. - Index const kResidue = desc.k % kBlock; - if (kResidue > 0) { - move_to_residue_offset = (desc.k - kResidue) * stride; - } else { - move_to_residue_offset = (desc.k - kBlock) * stride; - } - - Base::Params::initialize(ptr, 0, stride_h, 1, inc_d, inc_h, 0, inc_advance); + Base::Params::initialize( + ptr, stride_d, stride_h, 1, inc_d, inc_h, 0, inc_advance); return 0; } - - // The extra offset to control moving to the residue. - Index move_to_residue_offset; }; - /// Ctor. - CUTLASS_DEVICE GemmGlobalIteratorAb(Params const& _params, - const Coord<3>& bounds, - const Coord<3>& block, - ThreadOffset thread_offset_func = ThreadOffset()) - : params(_params) { - thread_offset = thread_offset_func(); - // The column. 
- Index block_h = thread_offset[1]; - // The contiguous dimension. - Index block_w = thread_offset[2]; + /// Offset of an individual lane from the start of the tile + Coord<4> thread_offset; + /// The parameters + Params params; + /// The predicates. + PredicateVector predicates; - // Add the blocks indices. - if (kAdvance == IteratorAdvance::kH) { - block_h += block[1]; - block_w += block[2]; - - } else { - block_h += block[2]; - block_w += block[1]; - } - - // Setup the pointer. - params.pointer += (block_h * params.stride_h + block_w); - - // Initialize predicates - initialize_predicates(bounds, make_Coord(0, block_h, block_w)); - } - - /// The accessor. - CUTLASS_DEVICE void get(typename Base::AccessType& value, int d, int h, int w, int c) const { - int const imm = - ComputeOffsetFromStrides::get(0, 0, w, c); - Load::load(value, params.pointer, imm); - } - - /// Increment the pointer in the H dimension. - CUTLASS_DEVICE void inc_h() { params.pointer += params.inc_h; } - /// Increment the pointer in the D dimension. - CUTLASS_DEVICE void inc_d() { params.pointer += params.inc_d; } - /// Increment the pointer to move to the next iteration. - CUTLASS_DEVICE void inc_advance() { params.pointer += params.inc_advance; } - - /// Initialize the predicates. - CUTLASS_DEVICE void initialize_predicates(const Coord<3>& bounds, const Coord<3>& block) { + CUTLASS_HOST_DEVICE void initialize_predicates(const Coord<3>& bounds, const Coord<3>& block_offset) { // Setup the masks to control loads. predicates.fill(0); - int bounds_h, bounds_w; - if (kAdvance == IteratorAdvance::kH) { - bounds_w = bounds[2] - block[2]; - bounds_h = bounds[1]; - - } else { - bounds_w = bounds[1]; - bounds_h = bounds[2] - block[1]; - } - // Fill in the bits of the predicate vector. 
for (int d = 0; d < Base::Iterations::kD; ++d) { for (int h = 0; h < Base::Iterations::kH; ++h) { for (int w = 0; w < Base::Iterations::kW; ++w) { for (int c = 0; c < Base::Iterations::kC; ++c) { - bool flag = w * Base::Delta::kW < bounds_w; + bool flag = w * Base::Delta::kW + thread_offset[2] + block_offset[2] < bounds[2]; if (kAdvance == IteratorAdvance::kH) { - flag = flag && (h * Base::Delta::kH + d * Base::Delta::kD) < bounds_h; + flag = + flag && + (h * Base::Delta::kH + d * Base::Delta::kD) + thread_offset[1] + block_offset[1] < + bounds[1]; } else { - flag = flag && (h * Base::Delta::kH) < bounds_h; + flag = flag && (h * Base::Delta::kH) + thread_offset[1] + block_offset[1] < bounds[1]; } int const bit = ComputeOffsetFromShape::get(d, h, w, c); predicates.set(bit, flag); @@ -323,31 +266,44 @@ struct GemmGlobalIteratorAb } } - /// Move to residue portion. - CUTLASS_DEVICE void move_to_residue(Index k) { - // Store the pointer and the predicates. - stored_pointer = params.pointer; - stored_predicates = predicates; + /// Ctor. + CUTLASS_HOST_DEVICE GemmGlobalIteratorAb(Params const& _params, + const Coord<3>& bounds, + const Coord<3>& threadblock_offset, + ThreadOffset thread_offset_func = ThreadOffset()) + : params(_params) { + thread_offset = thread_offset_func(); + // Setup the pointer. + params.pointer += ((threadblock_offset[1] + thread_offset[1]) * params.stride_h + + (threadblock_offset[2] + thread_offset[2])); - // Move the pointer to the residue. - params.pointer += params.move_to_residue_offset; + } - // The dimensions of the tile. - int const kH = TileTraits_::Tile::kH; - int const kW = TileTraits_::Tile::kW * TileTraits_::kAccessSize; + /// Increment the pointer in the W dimension. + CUTLASS_HOST_DEVICE void inc_w() { Base::inc_w(); } + /// Increment the pointer in the H dimension. + CUTLASS_HOST_DEVICE void inc_h() { params.pointer += params.inc_h; } + /// Increment the pointer in the D dimension. 
+ CUTLASS_HOST_DEVICE void inc_d() { params.pointer += params.inc_d; }
+  /// Increment the pointer to move to the next iteration.
+  CUTLASS_HOST_DEVICE void inc_advance() { params.pointer += params.inc_advance; }
 
-  // The unrolling factor.
-  int const kUnroll = kAdvance == IteratorAdvance::kH ? kH : kW;
-
-  // Clear the predicates for the residue. TODO: We can do something smarter.
-  int const kResidue = (int)(k % (Index)kUnroll);
-  if (kResidue > 0) {
-  residue(kResidue);
-  }
+  /// Loads a single fragment element from memory
+  CUTLASS_HOST_DEVICE void load_element(
+  typename Base::AccessType& value, int d, int h, int w, int c) const {
+  int const offset =
+  ComputeOffsetFromStrides::get(0, 0, w, c);
+  Load::load(value, params.pointer, offset);
 }
 
 /// That's the residue! Update the predicates.
-  CUTLASS_DEVICE void residue(Index k) {
+  CUTLASS_HOST_DEVICE void residue(Index k) {
 // The coordinates of the thread.
 Index block_h = thread_offset[1];
 // The contiguous dimension.
@@ -375,26 +331,63 @@ struct GemmGlobalIteratorAb
 }
 }
 
-  /// Rollback to beginning of first tile and initialize predicates.
-  CUTLASS_DEVICE void rollback() {
-  params.pointer = stored_pointer;
-  predicates = stored_predicates;
-  }
-
-  /// Is the iterator valid?
-  CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const {
+  /// Is the iterator valid?
+  CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const {
 int const bit = ComputeOffsetFromShape::get(d, h, w, c);
 return predicates[bit];
 }
 
-  /// Offset of an individual lane from the start of the tile
-  Coord<4> thread_offset;
-  /// The parameters
-  Params params;
-  /// The pointer.
-  typename Base::Scalar const* stored_pointer;
-  /// The predicates.
- PredicateVector predicates, stored_predicates; + /// Adds a vector offset to the iterator + CUTLASS_HOST_DEVICE GemmGlobalIteratorAb & operator+=(Coord<3> const &offset) { + + long long _offset = offset.template dot( + make_Coord(params.stride_d, params.stride_h, params.stride_w) + ); + + params.pointer += _offset; + return *this; + } + + CUTLASS_HOST_DEVICE void add_pointer_offset(Index offset) { params.pointer += offset; } + + CUTLASS_HOST_DEVICE Index stride_advance(void) { + Index stride = params.stride_h; + if (kAdvance == IteratorAdvance::kW) { + stride = params.stride_w; + } + return stride; + } + + template + CUTLASS_HOST_DEVICE void load_post_increment(Fragment& fragment) { + typename Base::FragmentIterator frag_iterator(fragment); + for (int d = 0; d < Base::Iterations::kD; ++d) { + for (int h = 0; h < Base::Iterations::kH; ++h) { + for (int w = 0; w < Base::Iterations::kW; ++w) { + for (int c = 0; c < Base::Iterations::kC; ++c) { + if (valid(d, h, w, c)) { + load_element( + reinterpret_cast(frag_iterator.at(d, h, w, c)), + d, + h, + w, + c); + } + } + if (w < Base::Iterations::kW - 1) { + inc_w(); + } + } + if (h < Base::Iterations::kH - 1) { + inc_h(); + } + } + if (d < Base::Iterations::kD - 1) { + inc_d(); + } + } + inc_advance(); + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -433,6 +426,8 @@ struct GemmGlobalIteratorCd : public TileIteratorBasepointer = pointer; + // Stride per batch + stride_d = batch_stride; // Each column of the matrix. - stride_h = TileTraits_::ThreadsDelta::kH * ld; + stride_h = TileTraits_::ThreadsDelta::kH * ldm; // Each thread output 1 column per iteration. The stride between columns is given by the // number of scalars that are loaded per LDS for B. 
- inc_h = ld * TileTraits_::kStrideH; + inc_h = ldm * TileTraits_::kStrideH; inc_advance = - (ld - ld * TileTraits_::kStrideH * (Base::Iterations::kH - 1)) + epilogue_stride_w; + (ldm - ldm * TileTraits_::kStrideH * (Base::Iterations::kH - 1)) + epilogue_stride_w; predicate_offset = bound; predicate_inc_h = TileTraits_::kStrideH; @@ -464,75 +465,173 @@ struct GemmGlobalIteratorCd : public TileIteratorBase thread_offset; + /// The predicates for the row. + cutlass::PredicateVector predicates; /// Ctor. - CUTLASS_DEVICE GemmGlobalIteratorCd() {} + CUTLASS_HOST_DEVICE GemmGlobalIteratorCd(Params const& _params, + const Coord<3>& bounds, + const Coord<3>& block_offset, + ThreadOffset thread_offset_func = ThreadOffset()) + : params(_params) { + thread_offset = thread_offset_func(); + // Prepare the vector of predicates. + for (int i = 0; i < Base::Iterations::kW; ++i) { + predicates.set(i, thread_offset[2] + i * Base::Delta::kW < bounds[2]); + } + } /// Ctor. - CUTLASS_DEVICE GemmGlobalIteratorCd(Params const& params, - const Coord<3>& bounds, - const Coord<3>& block, - int offset = 0, - int pred_offset = 0, - ThreadOffset thread_offset_func = ThreadOffset()) - : params(params) { + CUTLASS_HOST_DEVICE GemmGlobalIteratorCd(Params const& _params, + const Coord<3>& bounds, + const Coord<3>& block, + int offset = 0, + int pred_offset = 0, + ThreadOffset thread_offset_func = ThreadOffset()) + : params(_params) { thread_offset = thread_offset_func(); // Each warp works on a different column of the tile. int const h = thread_offset[1] + block[1]; // Each lane writes a different element. int const w = thread_offset[2] + block[2]; // Setup the pointer. - this->params.pointer += ((h * params.stride_h + w) + offset); + params.pointer += ((h * params.stride_h + w) + offset); // Prepare the vector of predicates. 
for (int i = 0; i < Base::Iterations::kW; ++i) { predicates.set(i, w + i * Base::Delta::kW < bounds[2]); } - this->params.predicate_offset -= (h + pred_offset); - } - - /// The accessor. - CUTLASS_DEVICE void get(typename Base::AccessType& value, int d, int h, int w, int c) const { - int const imm = - ComputeOffsetFromStrides::get(0, 0, w, c); - Load::load(value, params.pointer, imm); + params.predicate_offset -= (h + pred_offset); } /// Increment the pointer in the C dimension. - CUTLASS_DEVICE void inc_c() {} + CUTLASS_HOST_DEVICE void inc_c() {} /// Increment the pointer in the W dimension. - CUTLASS_DEVICE void inc_w() {} + CUTLASS_HOST_DEVICE void inc_w() {} /// Increment the pointer in the H dimension. - CUTLASS_DEVICE void inc_h() { + CUTLASS_HOST_DEVICE void inc_h() { params.pointer += params.inc_h; params.predicate_offset -= params.predicate_inc_h; } /// Increment the pointer in the D dimension. - CUTLASS_DEVICE void inc_d() {} + CUTLASS_HOST_DEVICE void inc_d() {} /// Increment the pointer to move to the next iteration. - CUTLASS_DEVICE void inc_advance() { + CUTLASS_HOST_DEVICE void inc_advance() { params.pointer += params.inc_advance; - this->params.predicate_offset -= params.predicate_inc_advance; + params.predicate_offset -= params.predicate_inc_advance; } - /// The accessor. - CUTLASS_DEVICE void set(typename Base::AccessType const& value, int d, int h, int w, int c) { - int const imm = - ComputeOffsetFromStrides::get(0, 0, w, c); - Store::store( - value, params.pointer, imm); + /// Adds a vector offset to the iterator + CUTLASS_HOST_DEVICE GemmGlobalIteratorCd & operator+=(Coord<3> const &offset) { + long long _offset = offset.template dot( + make_Coord(params.stride_d, params.stride_h, 1) + ); + params.pointer += _offset; + return *this; } - /// Test the validity of the iterator. - CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const { + /// Loads a single fragment element from memory. 
+ CUTLASS_HOST_DEVICE void load_element( + typename Base::AccessType& value, int d, int h, int w, int c) const { + int const offset = + ComputeOffsetFromStrides::get(d, h, w, c); + Load::load(value, params.pointer, offset); + } + + /// Stores a single fragment element into memory. + CUTLASS_HOST_DEVICE void store_element( + typename Base::AccessType const& value, int d, int h, int w, int c) { + int const offset = + ComputeOffsetFromStrides::get(d, h, w, c); + Store::store(value, params.pointer, offset); + } + + /// Test the validity of the + CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const { return predicates.at(w) && params.predicate_offset > 0; } - /// The predicates for the row. - cutlass::PredicateVector predicates; + /// add pointer offset + CUTLASS_HOST_DEVICE void add_pointer_offset(Index offset) { params.pointer += offset; } + + /// Loads and increments iterator + template + CUTLASS_HOST_DEVICE void load_post_increment(Fragment& fragment) { + typename Base::FragmentIterator frag_iterator(fragment); + for (int d = 0; d < Base::Iterations::kD; ++d) { + for (int h = 0; h < Base::Iterations::kH; ++h) { + for (int w = 0; w < Base::Iterations::kW; ++w) { + for (int c = 0; c < Base::Iterations::kC; ++c) { + if (valid(d, h, w, c)) { + load_element( + reinterpret_cast(frag_iterator.at(d, h, w, c)), + d, + h, + w, + c); + } + } + if (w < Base::Iterations::kW - 1) { + inc_w(); + } + } + if (h < Base::Iterations::kH - 1) { + inc_h(); + } + } + if (d < Base::Iterations::kD - 1) { + inc_d(); + } + } + inc_advance(); + } + + template + CUTLASS_HOST_DEVICE void store_post_increment(Fragment& fragment) { + typename Base::FragmentIterator frag_iterator(fragment); + for (int d = 0; d < Base::Iterations::kD; ++d) { + for (int h = 0; h < Base::Iterations::kH; ++h) { + for (int w = 0; w < Base::Iterations::kW; ++w) { + for (int c = 0; c < Base::Iterations::kC; ++c) { + if (valid(d, h, w, c)) { + store_element( + reinterpret_cast(frag_iterator.at(d, h, w, c)), + 
d, + h, + w, + c); + } + } + if (w < Base::Iterations::kW - 1) { + inc_w(); + } + } + if (h < Base::Iterations::kH - 1) { + inc_h(); + } + } + if (d < Base::Iterations::kD - 1) { + inc_d(); + } + } + inc_advance(); + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cutlass/gemm/gemm_operand.h b/cutlass/gemm/gemm_operand.h index 737f993f0..2b4dcdc91 100644 --- a/cutlass/gemm/gemm_operand.h +++ b/cutlass/gemm/gemm_operand.h @@ -28,9 +28,9 @@ */ #pragma once -#include -#include -#include +#include "cutlass/matrix_traits.h" +#include "cutlass/reshape_tile.h" +#include "cutlass/util/platform.h" namespace cutlass { namespace gemm { diff --git a/cutlass/gemm/gemm_shared_stream.h b/cutlass/gemm/gemm_shared_stream.h index c6ff7bd97..df20bd6ca 100644 --- a/cutlass/gemm/gemm_shared_stream.h +++ b/cutlass/gemm/gemm_shared_stream.h @@ -28,7 +28,8 @@ */ #pragma once -#include +#include "cutlass/tensor_ref.h" +#include "cutlass/gemm/gemm_shared_tile.h" namespace cutlass { namespace gemm { @@ -56,6 +57,11 @@ struct SharedLoadStream { ""); /// The output fragment. typedef TransformedFragment Fragment; + /// Scalar data type + typedef typename Iterator::Scalar Scalar; + + /// Reference type to a tensor + typedef TensorRef TensorRef; /// The params. struct Params { @@ -73,29 +79,38 @@ struct SharedLoadStream { CUTLASS_DEVICE SharedLoadStream() {} /// Ctor. - CUTLASS_DEVICE SharedLoadStream(Params const ¶ms, SharedStorage &shared_storage) { - this->initialize(params, shared_storage); + CUTLASS_DEVICE SharedLoadStream(Params const ¶ms, TensorRef const &ref) { + this->initialize(params, ref); } /// Initialize the stream. - CUTLASS_DEVICE void initialize(Params const ¶ms, SharedStorage &shared_storage) { + CUTLASS_DEVICE void initialize(Params const ¶ms, TensorRef const &ref) { // The iterator. 
- iterator = Iterator(params.iterator, shared_storage); + iterator = Iterator(params.iterator, ref.data()); // The transformer. transformer = Transformer(); } /// Load the data from shared memory to the fetch fragment. - CUTLASS_DEVICE void copy(FetchedFragment &fetched) { shared_iterator_load(iterator, fetched); } + CUTLASS_DEVICE void copy() { iterator.load_post_increment(fetched[0]); } /// Load the data from shared memory to the fetch fragment. - CUTLASS_DEVICE void copy(int d, FetchedFragment &fetched) { - shared_iterator_load(iterator, fetched, d); - } + CUTLASS_DEVICE void copy(int step) { iterator.load(fetched[step % 2], step); } /// Commit the data. - CUTLASS_DEVICE void commit(FetchedFragment &fetched, TransformedFragment &transformed) { - transformer.transform(fetched, transformed); + CUTLASS_DEVICE void commit() { transformer.transform(fetched[0], transformed[0]); } + + /// Commit the data. + CUTLASS_DEVICE void commit(int step) { + transformer.transform(fetched[step % 2], transformed[step % 2]); + } + + /// Returns the fragment for the given step + CUTLASS_DEVICE TransformedFragment &fragment(int step = 0) { return transformed[step % 2]; } + + /// Returns the fragment for the given step + CUTLASS_DEVICE TransformedFragment const &fragment(int step = 0) const { + return transformed[step % 2]; } /// Increment the stage. @@ -103,8 +118,12 @@ struct SharedLoadStream { /// The iterator. Iterator iterator; + /// Fetched fragment + FetchedFragment fetched[2]; /// The transformer. 
Transformer transformer; + /// Transformed fragment + TransformedFragment transformed[2]; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cutlass/gemm/gemm_shared_tile.h b/cutlass/gemm/gemm_shared_tile.h index 7c61e0229..78fb1f205 100644 --- a/cutlass/gemm/gemm_shared_tile.h +++ b/cutlass/gemm/gemm_shared_tile.h @@ -27,7 +27,7 @@ */ #pragma once -#include +#include "cutlass/gemm/gemm_operand.h" namespace cutlass { namespace gemm { diff --git a/cutlass/gemm/gemm_stream_pair.h b/cutlass/gemm/gemm_stream_pair.h new file mode 100644 index 000000000..0a6df15ed --- /dev/null +++ b/cutlass/gemm/gemm_stream_pair.h @@ -0,0 +1,251 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Defines a pair of GEMM tile streams
+*/
+#pragma once
+
+#include "cutlass/convert.h"
+#include "cutlass/matrix_traits.h"
+#include "cutlass/reshape_tile.h"
+#include "cutlass/tile_allocation.h"
+#include "cutlass/tile_iterator.h"
+
+#include "cutlass/gemm/clear_accumulators.h"
+#include "cutlass/gemm/gemm_config.h"
+#include "cutlass/gemm/gemm_global_stream.h"
+#include "cutlass/gemm/gemm_operand.h"
+#include "cutlass/gemm/gemm_shared_stream.h"
+#include "cutlass/gemm/threadblock_swizzle.h"
+
+namespace cutlass {
+namespace gemm {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Collect the global load streams for multiplicands. 
+template +struct GlobalLoadStreamPair { + // + // Type definitions + // + + /// Stream for A multiplicand + typedef StreamA_ StreamA; + + /// Stream for B multiplicand + typedef StreamB_ StreamB; + + /// Parameters object + struct Params { + /// Parameters object for StreamA + typename StreamA::Params stream_a; + + /// Parameters object for StreamB + typename StreamB::Params stream_b; + + /// Default constructor + CUTLASS_HOST_DEVICE + Params() {} + + /// Constructs a global load stream pair Params object + CUTLASS_HOST_DEVICE + Params(typename StreamA::Params const &_params_A, typename StreamB::Params const &_params_B) + : stream_a(_params_A), stream_b(_params_B) {} + }; + + /// Assumes the A stream defines the index type + typedef typename StreamA::Index Index; + + /// Shared memory allocation for threadblock-scoped GEMM tile + typedef ZipTileAllocation + ThreadblockTileStorage; + + /// ZipTensorRef to threadblock tiles + typedef typename ThreadblockTileStorage::TensorRef ThreadblockTileRef; + + /// Defines a structure containing shared storage for each pair + struct SharedStorage { + typename StreamA::SharedStorage stream_a; + typename StreamB::SharedStorage stream_b; + }; + + // + // Data members + // + + /// Stream for A multiplicand + StreamA stream_a; + + /// Stream for B multiplicand + StreamB stream_b; + + // + // Methods + // + + /// Ctor. 
+ CUTLASS_DEVICE GlobalLoadStreamPair(Params const ¶ms, + SharedStorage &shared_storage, + ThreadblockTileRef const &threadblock_tile_ref, + Coord<3> const &bounds, + Coord<3> const &block_offset = make_Coord(0, 0, 0)) + : stream_a(params.stream_a, + shared_storage.stream_a, + threadblock_tile_ref.first, + bounds, + block_offset), + stream_b(params.stream_b, + shared_storage.stream_b, + threadblock_tile_ref.second, + bounds, + block_offset) {} + + CUTLASS_DEVICE + GlobalLoadStreamPair & operator+=(Coord<3> const offset) { + stream_a += offset; + stream_b += offset; + return *this; + } + + /// Trigger the copies from shared memory to registers. + CUTLASS_DEVICE void copy() { + stream_a.copy(); + stream_b.copy(); + } + + /// Commit the data. + CUTLASS_DEVICE void commit() { + stream_a.commit(); + stream_b.commit(); + } + + /// Execute the residue code. + CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) { + stream_a.residue(k, skip_clear); + stream_b.residue(k, skip_clear); + } + + /// Move to residue. + CUTLASS_DEVICE void move_to_residue(Index k, Index kTileK) { + if (kResidueInProlog_) { + stream_a.move_to_residue(k, kTileK); + stream_b.move_to_residue(k, kTileK); + } else if (k < kTileK) { + residue(k, true); + } + } + + /// Rollback to beginning of first tile. + CUTLASS_DEVICE void rollback(bool kRollback) { + if (kResidueInProlog_ && kRollback) { + stream_a.rollback(); + stream_b.rollback(); + } + } +}; + +/// Collect the global load streams for multiplicands. 
+template +struct SharedStreamPair { + // + // Type definitions + // + + /// Stream for A multiplicand + typedef StreamA_ StreamA; + + /// Stream for B multiplicand + typedef StreamB_ StreamB; + + /// Parameters object passed to load iterators + struct Params { + /// + typename StreamA::Params stream_a; + + /// + typename StreamB::Params stream_b; + }; + + /// Shared memory allocation for threadblock-scoped GEMM tile + typedef ZipTensorRef + ThreadblockTileRef; + + // + // Data members + // + + /// The stream for A. + StreamA stream_a; + + /// The stream for B. + StreamB stream_b; + + // + // Methods + // + + /// Construct with the composable structure + CUTLASS_DEVICE SharedStreamPair(Params const ¶ms, ThreadblockTileRef const &threadblock_tile_ref) + : stream_a(params.stream_a, threadblock_tile_ref.first), + stream_b(params.stream_b, threadblock_tile_ref.second) {} + + /// Trigger the copies from shared memory to registers. + CUTLASS_DEVICE void copy(int step) { + stream_a.copy(step); + stream_b.copy(step); + } + + /// Commit the data. + CUTLASS_DEVICE void commit(int step) { + stream_a.commit(step); + stream_b.commit(step); + } + + /// The fragment A. + CUTLASS_DEVICE + typename StreamA::TransformedFragment const &fragment_a(int step) const { + return stream_a.fragment(step); + } + + /// The fragment B. + CUTLASS_DEVICE + typename StreamB::TransformedFragment const &fragment_b(int step) const { + return stream_b.fragment(step); + } + + /// Increment the stage. 
+ CUTLASS_DEVICE void inc_stage() { + stream_a.inc_stage(); + stream_b.inc_stage(); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace gemm +} // namespace cutlass diff --git a/cutlass/gemm/gemm_traits.h b/cutlass/gemm/gemm_traits.h index cb57c4d5c..fd6efb466 100644 --- a/cutlass/gemm/gemm_traits.h +++ b/cutlass/gemm/gemm_traits.h @@ -27,117 +27,27 @@ */ #pragma once -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include "cutlass/convert.h" +#include "cutlass/matrix_traits.h" +#include "cutlass/reshape_tile.h" +#include "cutlass/tile_allocation.h" +#include "cutlass/tile_iterator.h" +#include "cutlass/kernel_launch.h" +#include "cutlass/gemm/clear_accumulators.h" +#include "cutlass/gemm/gemm_config.h" +#include "cutlass/gemm/gemm_desc.h" +#include "cutlass/gemm/gemm_stream_pair.h" +#include "cutlass/gemm/gemm_global_stream.h" +#include "cutlass/gemm/gemm_operand.h" +#include "cutlass/gemm/gemm_shared_stream.h" +#include "cutlass/gemm/threadblock_swizzle.h" +#include "cutlass/gemm/gemm.h" namespace cutlass { namespace gemm { //////////////////////////////////////////////////////////////////////////////////////////////////// -template < - /// The scalar type for A. - typename ScalarA_, - /// The scalar type for B. - typename ScalarB_, - /// The scalar type for C. - typename ScalarC_, - /// The scalar type for D. - typename ScalarD_, - /// The output tile size for the GEMM KxNxM. - typename OutputTile_, - /// The functor to do the math. - typename MultiplyAdd_, - /// The number of scalars per LDG for A. - int kScalarsPerLdgA_, - /// The number of scalars per STS for A. - int kScalarsPerStsA_, - /// The number of scalars per LDG for A. - int kScalarsPerLdsA_, - /// The number of scalars per LDG for B. - int kScalarsPerLdgB_, - /// The number of scalars per STS for B. - int kScalarsPerStsB_, - /// The number of scalars per LDS for B. 
- int kScalarsPerLdsB_, - /// The number of scalars per LDG for C and STG for D. - int kScalarsPerLdgCAndStgD_, - /// The number of scalars per STS for D. - int kScalarsPerStsD_, - /// The number of scalars per LDS for D. - int kScalarsPerLdsD_, - /// The number of stages in shared memory to do single/double/triple-buffering. - int kStages_, - /// Do we do the residue in the prologue? - bool kResidueInPrologue_ = false> - -struct GemmConfig { - // - /// The scalar for A. - typedef ScalarA_ ScalarA; - /// The scalar for B. - typedef ScalarB_ ScalarB; - /// The scalar for C. - typedef ScalarC_ ScalarC; - /// The scalar for D. - typedef ScalarD_ ScalarD; - - /// The tile. - typedef OutputTile_ OutputTile; - /// The functor to do D = A*B + C. - typedef MultiplyAdd_ MultiplyAdd; - /// The shape of the instruction. - typedef typename MultiplyAdd::InstructionShape InstructionShape; - /// The number of accumulators per warp. - typedef typename MultiplyAdd::AccumulatorsPerWarp AccumulatorsPerWarp; - /// The accumulators. - typedef typename MultiplyAdd::Accumulators Accumulators; - - /// The number of warps. - typedef typename ShapeDiv::Shape Warps; - /// The default warp size (32 threads per warp). - static int const kWarpSize = cutlass::kWarpSize; - /// The numnber of threads. - static int const kThreads = ShapeCount::kCount * kWarpSize; - - /// The number of scalars per LDG/STS/LDS for A. - static int const kScalarsPerLdgA = kScalarsPerLdgA_; - static int const kScalarsPerStsA = kScalarsPerStsA_; - static int const kScalarsPerLdsA = kScalarsPerLdsA_; - - /// The number of scalars per LDG/STS/LDS for B. - static int const kScalarsPerLdgB = kScalarsPerLdgB_; - static int const kScalarsPerStsB = kScalarsPerStsB_; - static int const kScalarsPerLdsB = kScalarsPerLdsB_; - - /// The number of scalars per LDG for C. - static int const kScalarsPerLdgC = kScalarsPerLdgCAndStgD_; - - /// The number of scalars per STS/LDS/STG for D. 
- static int const kScalarsPerStgD = kScalarsPerLdgCAndStgD_; - static int const kScalarsPerStsD = kScalarsPerStsD_; - static int const kScalarsPerLdsD = kScalarsPerLdsD_; - - /// The number of accumulators that are going to be fed from one LDS A/B. - static int const kAccumulatorsPerLdsA = kScalarsPerLdsA / InstructionShape::kD; - static int const kAccumulatorsPerLdsB = kScalarsPerLdsB / InstructionShape::kD; - - /// The number of stages in shared memory to implement double, triple, more-buffering. - static int const kStages = kStages_; - - /// Do we do the residue in the prologue? - static bool const kResidueInPrologue = kResidueInPrologue_; -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - template struct GemmTileTraitsHelperA {}; @@ -416,60 +326,6 @@ struct GemmTileTraitsHelperB { //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct GemmResidue { - /// Move to residue portion. - template - static CUTLASS_DEVICE void move_to_residue(typename GemmTraits_::GlobalLoadStreamA& stream_a, - typename GemmTraits_::GlobalLoadStreamB& stream_b, - typename GemmTraits_::Index k) { - // The new code path in CUTLASS 1.0.1: We treat the residue in the prologue so we can have - // complete main loops after that. It helps simplify the logic in the main loop. - if (kIsPrologue) { - stream_a.move_to_residue(k); - stream_b.move_to_residue(k); - } - } - - /// Rollback to beginning of first tile and initialize predicates. - static CUTLASS_DEVICE void rollback(typename GemmTraits_::GlobalLoadStreamA& stream_a, - typename GemmTraits_::GlobalLoadStreamB& stream_b) { - stream_a.rollback(); - stream_b.rollback(); - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct GemmResidue { - /// Move to residue portion. 
- template - static CUTLASS_DEVICE void move_to_residue(typename GemmTraits_::GlobalLoadStreamA& stream_a, - typename GemmTraits_::GlobalLoadStreamB& stream_b, - typename GemmTraits_::Index k) { - // The index. - typedef typename GemmTraits_::Index Index; - // By how much we unroll the main loop. - Index const kUnroll = static_cast(GemmTraits_::OutputTile::kD); - - // Call the residue code. That's the same path as CUTLASS 1.0.0. - if (kIsPrologue && k < kUnroll) { - stream_a.residue(k, true); - stream_b.residue(k, true); - } else if (k <= kUnroll) { - stream_a.residue(k, false); - stream_b.residue(k, false); - } - } - - /// Rollback to beginning of first tile and initialize predicates. - static CUTLASS_DEVICE void rollback(typename GemmTraits_::GlobalLoadStreamA& stream_a, - typename GemmTraits_::GlobalLoadStreamB& stream_b) {} -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - template < /// The GEMM configuration. typename GemmConfig_, @@ -488,27 +344,27 @@ template < /// The index. typename Index_ = int, /// The tool used to clear accumulators. - typename ClearAccumulators_ = ClearAccumulators > + typename ClearAccumulators_ = ClearAccumulators > struct GemmTraits { - /// This class. + /// This traits typedef GemmTraits - This_; + GlobalLoadStreamA_, + GlobalLoadStreamB_, + SharedLoadStreamA_, + SharedLoadStreamB_, + Epilogue_, + BlockSwizzle_, + Index_, + ClearAccumulators_> This_; + + /// The struct that consumes this Traits + typedef typename cutlass::gemm::Gemm KernelClass; /// The configuration. typedef GemmConfig_ GemmConfig; /// The output tile. typedef typename GemmConfig::OutputTile OutputTile; - /// Is the residue treated in the prologue? - static bool const kResidueInPrologue = GemmConfig::kResidueInPrologue; /// The stream to load A from global memory to shared memory. typedef GlobalLoadStreamA_ GlobalLoadStreamA; @@ -544,18 +400,30 @@ struct GemmTraits { /// Clear the accumulators. 
typedef ClearAccumulators_ ClearAccumulators; - /// The params. - struct Params { - /// The dimensions of the GEMM. - Index m, n, k; - /// The params for the A stream. - typename GlobalLoadStreamA::Params global_stream_a; - /// The params for the B stream. - typename GlobalLoadStreamB::Params global_stream_b; - /// The params for the A stream from shared memory. - typename SharedLoadStreamA::Params shared_stream_a; - /// The params for the B stream from shared memory. - typename SharedLoadStreamB::Params shared_stream_b; + /// Assemble the global load streams for A/B. + typedef GlobalLoadStreamPair + GlobalLoadStream; + + /// Memory needed to store the threadblock-scoped GEMM tile + typedef typename GlobalLoadStream::ThreadblockTileStorage ThreadblockTileStorage; + + /// Assemble the shared load streams for A/B. + typedef SharedStreamPair SharedStream; + + /// Parameters object constructable on the host. + struct Params : public KernelLaunchConfiguration { + + /// GEMM problem size + GemmCoord problem_size; + + /// Parameters object for the global load stream + typename GlobalLoadStream::Params global_to_shared_stream; + + /// Parameters object for the shared load stream + typename SharedStream::Params shared_stream; + /// The params for the epilogue. typename Epilogue::Params epilogue; @@ -563,21 +431,36 @@ struct GemmTraits { template CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) { // Set the problem size. - this->m = desc.m; - this->n = desc.n; - this->k = desc.k; + problem_size = desc.problem_size; - // Initialize the iterator for A. - int error_code = - global_stream_a.initialize(desc, reinterpret_cast(desc.d_a), desc.lda); + // Compute grid dimensions + BlockSwizzle block_swizzle; + this->block = dim3(GemmConfig::kThreads); + this->grid = block_swizzle.get_grid_layout( + problem_size, + make_Coord_from_shape()); + // Compute offset to residue. + Index gemm_k = problem_size[0]; + Index offset_to_residue = (gemm_k % OutputTile::kD) ? 
gemm_k - (gemm_k % OutputTile::kD) : 0; + + // Initialize parameters objects for + int error_code = global_to_shared_stream.stream_a.initialize( + desc.A.data(), + desc.batch_stride_A, + desc.A.leading_dim(), + offset_to_residue + ); if (error_code) { return error_code; } - // Initialize the iterator for B. - error_code = - global_stream_b.initialize(desc, reinterpret_cast(desc.d_b), desc.ldb); + error_code = global_to_shared_stream.stream_b.initialize( + desc.B.data(), + desc.batch_stride_B, + desc.B.leading_dim(), + offset_to_residue + ); if (error_code) { return error_code; @@ -586,24 +469,81 @@ struct GemmTraits { // The epilogue. return epilogue.initialize(desc); } - }; - // The storage for A. - template - union StreamSharedStorage { - // The storage needed by the global stream. - typename GlobalLoadStream_::SharedStorage global; - // The storage needed by the shared stream. - typename SharedLoadStream_::SharedStorage shared; + /// Helper to construct a GEMM params using a BLAS-like API + CUTLASS_HOST_DEVICE int initialize(Index m, + Index n, + Index k, + typename Epilogue::Scalar alpha, + ScalarA const* d_a, + Index lda, + ScalarB const* d_b, + Index ldb, + typename Epilogue::Scalar beta, + ScalarC const* d_c, + Index ldc, + ScalarD* d_d, + Index ldd) { + GemmDesc desc( + GemmCoord(k, n, m, 1), + alpha, + TensorRef(d_a, lda), + TensorRef(d_b, ldb), + beta, + TensorRef(d_c, ldc), + TensorRef(d_d, ldd) + ); + + return this->initialize(desc); + } + + /// Helper to construct a batched GEMM params + CUTLASS_HOST_DEVICE int initialize(Index m, + Index n, + Index k, + typename Epilogue::Scalar alpha, + ScalarA const* d_a, + Index lda, + long long int batch_stride_A, + ScalarB const* d_b, + Index ldb, + long long int batch_stride_B, + typename Epilogue::Scalar beta, + ScalarC const* d_c, + Index ldc, + long long int batch_stride_C, + ScalarD* d_d, + Index ldd, + long long int batch_stride_D, + Index batch_count) { + + GemmDesc desc( + GemmCoord(k, n, m, batch_count), 
+ alpha, + TensorRef(d_a, lda), + batch_stride_A, + TensorRef(d_b, ldb), + batch_stride_B, + beta, + TensorRef(d_c, ldc), + batch_stride_C, + TensorRef(d_d, ldd), + batch_stride_D + ); + + return this->initialize(desc); + } }; // The storage for the main loop + prologue. struct MainLoopSharedStorage { - // The storage to shuffle the A matrix in shared memory. - StreamSharedStorage stream_a; - // The storage to shuffle the B matrix in shared memory. - StreamSharedStorage stream_b; - // The storage to clear the accumulators if needed. + /// Stores the threadblock tile + ThreadblockTileStorage threadblock_tile; + + /// Storage for GEMM global stream + typename GlobalLoadStream::SharedStorage global_to_shared_stream; + + /// Storage for clearing accumulators typename ClearAccumulators::SharedStorage clear; }; @@ -615,108 +555,18 @@ struct GemmTraits { typename Epilogue::SharedStorage epilogue; }; - /// Assemble the global load streams for A/B. - struct GlobalLoadStream { - /// Ctor. - CUTLASS_DEVICE GlobalLoadStream(Params const& params, - SharedStorage& shared_storage, - dim3 const& block) - : stream_a(params.global_stream_a, - shared_storage.main_loop.stream_a.global, - cutlass::make_Coord(0, params.k, params.m), - cutlass::make_Coord(0, 0, block.x)), - stream_b(params.global_stream_b, - shared_storage.main_loop.stream_b.global, - cutlass::make_Coord(0, params.k, params.n), - make_Coord(0, 0, block.y)) {} - - /// Trigger the copies from shared memory to registers. - CUTLASS_DEVICE void copy() { - stream_a.copy(); - stream_b.copy(); - } - - /// Commit the data. - CUTLASS_DEVICE void commit() { - stream_a.commit(); - stream_b.commit(); - } - - /// Move to residue portion. - template - CUTLASS_DEVICE void move_to_residue(Index k) { - GemmResidue::move_to_residue(stream_a, stream_b, k); - } - - /// Rollback to beginning of first tile and initialize predicates. - CUTLASS_DEVICE void rollback() { GemmResidue::rollback(stream_a, stream_b); } - - /// The stream for A. 
- GlobalLoadStreamA stream_a; - /// The stream for B. - GlobalLoadStreamB stream_b; - }; - - /// Assemble the shared load stream for A/B. - struct SharedLoadStream { - /// Ctor. - CUTLASS_DEVICE SharedLoadStream(Params const& params, SharedStorage& shared_storage) { - stream_a.initialize(params.shared_stream_a, shared_storage.main_loop.stream_a.shared); - stream_b.initialize(params.shared_stream_b, shared_storage.main_loop.stream_b.shared); - } - - /// Trigger the copies from shared memory to registers. - CUTLASS_DEVICE void copy(int step) { - stream_a.copy(step, fetched_a[step % 2]); - stream_b.copy(step, fetched_b[step % 2]); - } - - /// Commit the data. - CUTLASS_DEVICE void commit(int step) { - stream_a.commit(fetched_a[step % 2], transformed_a[step % 2]); - stream_b.commit(fetched_b[step % 2], transformed_b[step % 2]); - } - - /// The fragment A. - CUTLASS_DEVICE typename SharedLoadStreamA::Fragment const& fragment_a(int step) const { - return transformed_a[step % 2]; - } - - /// The fragment B. - CUTLASS_DEVICE typename SharedLoadStreamB::Fragment const& fragment_b(int step) const { - return transformed_b[step % 2]; - } - - /// Increment the stage. - CUTLASS_DEVICE void inc_stage() { - stream_a.inc_stage(); - stream_b.inc_stage(); - } - - /// The stream for A. - SharedLoadStreamA stream_a; - /// The fragments to fetch A. - typename SharedLoadStreamA::FetchedFragment fetched_a[2]; - /// The fragments to transform A. - typename SharedLoadStreamA::TransformedFragment transformed_a[2]; - /// The stream for B. - SharedLoadStreamB stream_b; - /// The fragments to fetch B. - typename SharedLoadStreamB::FetchedFragment fetched_b[2]; - /// The fragments to transform B. - typename SharedLoadStreamB::TransformedFragment transformed_b[2]; - }; - /// The memory fence for shared loads. 
static CUTLASS_DEVICE void shared_load_fence(bool in_loop) { if (SharedLoadStreamA::Iterator::kRequiresLoadFence || SharedLoadStreamB::Iterator::kRequiresLoadFence) { - __syncthreads(); + __syncthreads(); } } /// The memory fence for shared stores. - static CUTLASS_DEVICE void shared_store_fence(bool in_loop) { __syncthreads(); } + static CUTLASS_DEVICE void shared_store_fence(bool in_loop) { + __syncthreads(); + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -735,7 +585,10 @@ struct SimplifiedGemmTraitsHelper { MemorySpace::kShared> SharedStoreIteratorA; /// The stream to load A from global memory to shared memory. - typedef GlobalLoadStream + typedef GlobalLoadStream GlobalLoadStreamA; /// The global iterator to load B from global memory. @@ -750,7 +603,10 @@ struct SimplifiedGemmTraitsHelper { MemorySpace::kShared> SharedStoreIteratorB; /// The stream to load B from global memory to shared memory. - typedef GlobalLoadStream + typedef GlobalLoadStream GlobalLoadStreamB; /// The iterator to load A from shared memory. diff --git a/cutlass/gemm/hgemm_global_tile.h b/cutlass/gemm/hgemm_global_tile.h index f14dbb311..9d5ffe850 100644 --- a/cutlass/gemm/hgemm_global_tile.h +++ b/cutlass/gemm/hgemm_global_tile.h @@ -29,10 +29,10 @@ */ #pragma once -#include -#include -#include -#include +#include "cutlass/coord.h" +#include "cutlass/gemm/gemm_global_tile.h" +#include "cutlass/matrix_traits.h" +#include "cutlass/reshape_tile.h" namespace cutlass { namespace gemm { @@ -63,14 +63,14 @@ struct HgemmCrosswiseGlobalTileTraits : public GemmGlobalTileTraits< /// The threads. typedef typename Base::Threads Threads; /// The threads strides. - typedef Shape<1, 2, Base::Tile::kC> ThreadsDelta; + typedef Shape<1, 2, Base::VectorizedTile::kC> ThreadsDelta; /// The strides in each dimension between different loads/stores. typedef Shape Delta; /// The number of iterations needed to load/store the tile. 
- typedef Shape + Base::VectorizedTile::kW / Base::Threads::kW, + Base::VectorizedTile::kC / Base::kAccessSize> Iterations; /// Computes the thread offset in (H, W) based on thread ID struct ThreadOffset { diff --git a/cutlass/gemm/hgemm_multiply_add.h b/cutlass/gemm/hgemm_multiply_add.h index ebbdd06e8..7217d82c5 100644 --- a/cutlass/gemm/hgemm_multiply_add.h +++ b/cutlass/gemm/hgemm_multiply_add.h @@ -28,9 +28,9 @@ */ #pragma once -#include +#include "cutlass/fragment.h" -#include +#include "cutlass/gemm/thread_multiply_add.h" namespace cutlass { namespace gemm { @@ -38,16 +38,18 @@ namespace gemm { //////////////////////////////////////////////////////////////////////////////////////////////////// /// Template performing matrix multiply-add operation within a thread -template -struct ThreadMultiplyAdd { +template +struct ThreadMultiplyAdd { /// The shape of the instruction. typedef Shape<1, 1, 2, 1> InstructionShape; /// The number of accumulators per thread. - typedef AccumulatorsPerThread_ AccumulatorsPerThread; + typedef ThreadGemmShape_ ThreadGemmShape; + /// Aliased for compatibility. Will be removed for CUTLASS v2.0. + typedef ThreadGemmShape AccumulatorsPerThread; /// The number of threads per warp. typedef ThreadsPerWarp_ ThreadsPerWarp; /// The number of accumulators per warp. - typedef typename ShapeMul::Shape AccumulatorsPerWarp; + typedef typename ShapeMul::Shape AccumulatorsPerWarp; /// The type for A. typedef half ScalarA; /// The fragment for A. 
@@ -88,9 +90,9 @@ struct ThreadMultiplyAdd -#include +#include "cutlass/fragment.h" namespace cutlass { namespace gemm { diff --git a/cutlass/gemm/hgemm_traits.h b/cutlass/gemm/hgemm_traits.h index b08645bf4..2261bb4b3 100644 --- a/cutlass/gemm/hgemm_traits.h +++ b/cutlass/gemm/hgemm_traits.h @@ -27,18 +27,18 @@ */ #pragma once -#include -#include +#include "cutlass/convert.h" +#include "cutlass/reshape_tile.h" -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/gemm_epilogue.h" +#include "cutlass/gemm/gemm_epilogue_traits.h" +#include "cutlass/gemm/gemm_global_tile.h" +#include "cutlass/gemm/gemm_shared_tile.h" +#include "cutlass/gemm/gemm_traits.h" +#include "cutlass/gemm/hgemm_global_tile.h" +#include "cutlass/gemm/hgemm_multiply_add.h" +#include "cutlass/gemm/hgemm_swizzle.h" namespace cutlass { namespace gemm { @@ -48,46 +48,52 @@ namespace gemm { template < /// The tile size for the GEMM KxNxM. typename OutputTile_, - /// The number of accumulators per thread. - typename AccumulatorsPerThread_, + /// Tile size for thread-level GEMM (K-by-N-by-M) + typename ThreadGemmShape_, /// The number of scalars per LDG for A. int kScalarsPerLdgA_ = 2, /// The number of scalars per LDG for B. int kScalarsPerLdgB_ = 2> -struct HgemmConfig - : public GemmConfig< - /// The scalar type for A. - half, - /// The scalar type for B. - half, - /// The scalar type for C. - half, - /// The scalar type for D. - half, - /// The tile size for the GEMM KxNxM. - OutputTile_, - /// The functor to do the math in the main loop. - ThreadMultiplyAdd, half, half, half>, - /// The number of scalars per LDG for A. - kScalarsPerLdgA_, - /// The number of scalars per STS for A. - kScalarsPerLdgA_, - /// The number of scalars per LDS for A. - 8, - /// The number of scalars per LDG for B. - kScalarsPerLdgB_, - /// The number of scalars per STS for B. 
- kScalarsPerLdgB_, - /// The number of scalars per LDS for B. - 8, - /// The number of scalars per LDG for C and STG for D. - 2, - /// The number of scalars per STS for D. - 8, - /// The number of scalars per LDS for D. - 2, - /// The number of stages in shared memory. - 2> {}; +struct HgemmConfig : public GemmConfig< + /// The scalar type for A. + half, + /// The scalar type for B. + half, + /// The scalar type for C. + half, + /// The scalar type for D. + half, + /// The tile size for the GEMM KxNxM. + OutputTile_, + /// The functor to do the math in the main loop. + ThreadMultiplyAdd, half, half, half>, + /// The number of scalars per LDG for A. + kScalarsPerLdgA_, + /// The number of scalars per STS for A. + kScalarsPerLdgA_, + /// The number of scalars per LDS for A. + 8, + /// The number of scalars per LDG for B. + kScalarsPerLdgB_, + /// The number of scalars per STS for B. + kScalarsPerLdgB_, + /// The number of scalars per LDS for B. + 8, + /// The number of scalars per LDG for C and STG for D. + 2, + /// The number of scalars per STS for D. + 8, + /// The number of scalars per LDS for D. + 2, + /// The number of stages in shared memory. + 2, + /// kResidueSeparate + false, + /// kResidueInPrologue + true, + /// kLaunchBounds + false + > {}; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -147,7 +153,6 @@ struct HgemmTileTraitsHelperA GemmConfig_::kScalarsPerLdgA> GlobalTileTraits; - /// The skew. static int const kSkewA = 128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2; /// The traits class to build the iterator to store data to shared memory for A^T. @@ -215,7 +220,6 @@ struct HgemmTileTraitsHelperB GemmConfig_::kScalarsPerLdgB> GlobalTileTraits; - /// The skew for B. static int const kSkewB = 128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2; /// The traits class to build the iterator to store data to shared memory for B^N. 
@@ -266,8 +270,8 @@ template < typename OutputTile_, /// The functor to do the math in the epilogue. typename EpilogueFunctor_, - /// The number of accumulators per thread. - typename AccumulatorsPerThread_ = Shape<8, 8, 16>, + /// Tile size for thread-level GEMM (K-by-N-by-M) + typename ThreadGemmShape_, /// The number of halfs loaded in one LDG for A. int kScalarsPerLdgA_ = 2, /// The number of halfs loaded in one LDG for B. @@ -276,8 +280,7 @@ template < typename Index_ = int> struct HgemmTraitsHelper { /// The HGEMM config. - typedef HgemmConfig - GemmConfig; + typedef HgemmConfig GemmConfig; /// The GEMM config for A. typedef HgemmTileTraitsHelperA GemmTileTraitsHelperA; /// The GEMM config for B. @@ -296,7 +299,10 @@ struct HgemmTraitsHelper { MemorySpace::kShared> SharedStoreIteratorA; /// The stream to load A from global memory to shared memory. - typedef GlobalLoadStream + typedef GlobalLoadStream GlobalLoadStreamA; /// The iterator to load B from global memory. @@ -312,7 +318,10 @@ struct HgemmTraitsHelper { MemorySpace::kShared> SharedStoreIteratorB; /// The stream to load B from global memory to shared memory. - typedef GlobalLoadStream + typedef GlobalLoadStream GlobalLoadStreamB; /// The iterator to load A from shared memory @@ -354,8 +363,8 @@ template < typename OutputTile_ = Shape<8, 128, 128>, /// The functor to do the math in the epilogue. typename EpilogueFunctor_ = LinearScaling, - /// The number of accumulators per thread. - typename AccumulatorsPerThread_ = Shape<8, 8, 16>, + /// Tile size for warp-level GEMM (K-by-N-by-M) + typename ThreadGemmShape_ = Shape<8, 8, 16>, /// The number of halfs loaded in one LDG for A. int kScalarsPerLdgA_ = 2, /// The number of halfs loaded in one LDG for B. 
@@ -367,7 +376,7 @@ template < kLayoutB_, OutputTile_, EpilogueFunctor_, - AccumulatorsPerThread_, + ThreadGemmShape_, kScalarsPerLdgA_, kScalarsPerLdgB_, Index_> > diff --git a/cutlass/gemm/igemm_epilogue.h b/cutlass/gemm/igemm_epilogue.h index 0d6998031..2ad24f32c 100644 --- a/cutlass/gemm/igemm_epilogue.h +++ b/cutlass/gemm/igemm_epilogue.h @@ -28,13 +28,13 @@ */ #pragma once -#include -#include -#include -#include -#include -#include -#include +#include "cutlass/convert.h" +#include "cutlass/fragment.h" +#include "cutlass/gemm/gemm_global_stream.h" +#include "cutlass/gemm/gemm_shared_stream.h" +#include "cutlass/gemm/igemm_global_tile.h" +#include "cutlass/reshape_tile.h" +#include "cutlass/tile_iterator.h" namespace cutlass { namespace gemm { @@ -269,8 +269,8 @@ struct IgemmEpilogueTraits : public GemmEpilogueTraits< typename Helper_::SharedStoreIteratorD, // The shared store transformer for D. typename Helper_::SharedStoreTransformerD, - // The iterator to load D from shared memory. - typename Helper_::SharedLoadIteratorD, + // The stream to load D from shared memory. + typename Helper_::SharedLoadStreamD, // The iterations. typename Helper_::Iterations, // The strides between iterations. @@ -294,9 +294,8 @@ struct IgemmEpilogue : public GemmEpilogue { /// Ctor. 
CUTLASS_DEVICE IgemmEpilogue(typename Base::Params const& params_, typename Base::SharedStorage& shared_storage_, - typename Base::Index m_, - typename Base::Index n_) - : Base(params_, shared_storage_, m_, n_) {} + Coord<3> const& _problem_size) + : Base(params_, shared_storage_, _problem_size) {} }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -309,9 +308,8 @@ struct IgemmEpilogue : public GemmEpilogue const& _problem_size) + : Base(params_, shared_storage_, _problem_size) {} }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cutlass/gemm/igemm_global_tile.h b/cutlass/gemm/igemm_global_tile.h index 3f594ac6a..7a9c1573a 100644 --- a/cutlass/gemm/igemm_global_tile.h +++ b/cutlass/gemm/igemm_global_tile.h @@ -32,9 +32,9 @@ */ #pragma once -#include -#include -#include +#include "cutlass/coord.h" +#include "cutlass/gemm/gemm_global_tile.h" +#include "cutlass/matrix_traits.h" namespace cutlass { namespace gemm { @@ -67,10 +67,10 @@ struct IgemmGlobalTileTraits : public GemmGlobalTileTraits< /// The strides in each dimension between different loads/stores. typedef Shape Delta; /// The number of iterations needed to load/store the tile. - typedef Shape + Base::VectorizedTile::kW / Base::Threads::kW, + Base::VectorizedTile::kC / Base::kAccessSize> Iterations; /// Computes the thread offset in (H, W) based on thread ID @@ -86,24 +86,11 @@ struct IgemmGlobalTileTraits : public GemmGlobalTileTraits< public: /// The threads strides. - typedef Shape<1, 4, Base::Tile::kC> ThreadsDelta; + typedef Shape<1, 4, Base::VectorizedTile::kC> ThreadsDelta; }; //////////////////////////////////////////////////////////////////////////////////////////////////// -/// Deprecated. Please use IgemmGlobalTileTraits instead. 
- -template -struct IgemmContiguousGlobalTileTraits - : public IgemmGlobalTileTraits {}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - template struct IgemmGlobalIteratorAb : public GemmGlobalIteratorAb { /// The base class. @@ -114,11 +101,11 @@ struct IgemmGlobalIteratorAb : public GemmGlobalIteratorAb /// Constructor. CUTLASS_DEVICE IgemmGlobalIteratorAb(typename Base::Params const& _params, const Coord<3>& bounds, - const Coord<3>& block, + const Coord<3>& threadblock_offset, ThreadOffset thread_offset_func = ThreadOffset()) - : Base(_params, bounds, block, thread_offset_func), in_residue_(false), mask_(0xffffffff) { + : Base(_params, bounds, threadblock_offset, thread_offset_func), mask_(0xffffffff) { // The number of elements read in a single iteration. - int const kBlock = TileTraits_::Tile::kW * TileTraits_::kAccessSize; + int const kBlock = TileTraits_::Tile::kW; // The residue. int const kResidue = (int)(bounds[1] % kBlock); @@ -129,28 +116,12 @@ struct IgemmGlobalIteratorAb : public GemmGlobalIteratorAb } } - /// The accessor. - CUTLASS_DEVICE void get(typename Base::AccessType& value, int d, int h, int w, int c) const { - Base::get(value, d, h, w, c); - if (in_residue_) { - reinterpret_cast(value) &= mask_; - } + CUTLASS_DEVICE void load_element( + typename Base::AccessType& value, int d, int h, int w, int c) const { + Base::load_element(value, d, h, w, c); + reinterpret_cast(value) &= mask_; } - /// Move to residue portion. - CUTLASS_DEVICE void move_to_residue(typename Base::Index k) { - Base::move_to_residue(k); - in_residue_ = true; - } - - /// Move back to the beginning of the first tile. - CUTLASS_DEVICE void rollback() { - Base::rollback(); - in_residue_ = false; - } - - /// Are we in the residue? - bool in_residue_; /// The mask to clean up the values. 
uint32_t mask_; }; diff --git a/cutlass/gemm/igemm_multiply_add.h b/cutlass/gemm/igemm_multiply_add.h index 5a8baec53..5ff6c7c1b 100644 --- a/cutlass/gemm/igemm_multiply_add.h +++ b/cutlass/gemm/igemm_multiply_add.h @@ -28,9 +28,9 @@ */ #pragma once -#include +#include "cutlass/fragment.h" -#include +#include "cutlass/gemm/thread_multiply_add.h" namespace cutlass { namespace gemm { @@ -38,16 +38,18 @@ namespace gemm { //////////////////////////////////////////////////////////////////////////////////////////////////// /// Template performing matrix multiply-add operation within a thread -template -struct ThreadMultiplyAdd { +template +struct ThreadMultiplyAdd { /// The shape of the instruction. typedef Shape<4, 1, 1> InstructionShape; - /// The number of accumulators per thread. - typedef AccumulatorsPerThread_ AccumulatorsPerThread; + /// Shape of the thread-level GEMM (K-by-N-by-M) + typedef ThreadGemmShape_ ThreadGemmShape; + /// Aliased for compatibility. Will be removed in CUTLASS v2.0 + typedef ThreadGemmShape AccumulatorsPerThread; /// The number of threads per warp. typedef ThreadsPerWarp_ ThreadsPerWarp; /// The number of accumulators per warp. - typedef typename ShapeMul::Shape AccumulatorsPerWarp; + typedef typename ShapeMul::Shape AccumulatorsPerWarp; /// The type for A. typedef int8_t ScalarA; /// The fragment for A. diff --git a/cutlass/gemm/igemm_swizzle.h b/cutlass/gemm/igemm_swizzle.h index 77cf7118d..fbb68d143 100644 --- a/cutlass/gemm/igemm_swizzle.h +++ b/cutlass/gemm/igemm_swizzle.h @@ -27,7 +27,7 @@ */ #pragma once -#include +#include "cutlass/fragment.h" namespace cutlass { namespace gemm { @@ -82,6 +82,11 @@ struct IgemmSwizzle { int a2 = src_int[i2]; int a3 = src_int[i3]; + // // DEBUG. 
+ // if (threadIdx.x == 0) { + // printf("a=0x%08x 0x%08x 0x%08x 0x%08x\n", a0, a1, a2, a3); + // } + int b0, b1, b2, b3, c0; asm volatile("prmt.b32 %0, %1, %2, 0x0040;" : "=r"(b0) : "r"(a0), "r"(a1)); asm volatile("prmt.b32 %0, %1, %2, 0x0040;" : "=r"(c0) : "r"(a2), "r"(a3)); @@ -99,6 +104,11 @@ struct IgemmSwizzle { asm volatile("prmt.b32 %0, %1, %2, 0x0073;" : "=r"(c0) : "r"(a2), "r"(a3)); asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b3) : "r"(b3), "r"(c0)); + // // DEBUG. + // if (threadIdx.x == 0) { + // printf("b=0x%08x 0x%08x 0x%08x 0x%08x\n", b0, b1, b2, b3); + // } + dst_int[i0] = b0; dst_int[i1] = b1; dst_int[i2] = b2; diff --git a/cutlass/gemm/igemm_traits.h b/cutlass/gemm/igemm_traits.h index 82f8de5cd..5bceeda92 100644 --- a/cutlass/gemm/igemm_traits.h +++ b/cutlass/gemm/igemm_traits.h @@ -29,18 +29,18 @@ */ #pragma once -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include "cutlass/convert.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/gemm_epilogue.h" +#include "cutlass/gemm/gemm_epilogue_traits.h" +#include "cutlass/gemm/gemm_global_tile.h" +#include "cutlass/gemm/gemm_shared_tile.h" +#include "cutlass/gemm/gemm_traits.h" +#include "cutlass/gemm/igemm_epilogue.h" +#include "cutlass/gemm/igemm_global_tile.h" +#include "cutlass/gemm/igemm_multiply_add.h" +#include "cutlass/gemm/igemm_swizzle.h" +#include "cutlass/reshape_tile.h" namespace cutlass { namespace gemm { @@ -52,49 +52,52 @@ template < typename OutputTile_, /// The output type. typename ScalarD_, - /// The number of accumulators per thread. - typename AccumulatorsPerThread_> -struct IgemmConfig - : public GemmConfig< - /// The scalar type for A. - int8_t, - /// The scalar type for B. - int8_t, - /// The scalar type for C. - ScalarD_, - /// The scalar type for D. - ScalarD_, - /// The tile size for the GEMM KxNxM. - OutputTile_, - /// The functor to do the math in the main loop. 
- ThreadMultiplyAdd, int8_t, int8_t, int>, - /// The number of scalars per LDG for A. - 4, - /// The number of scalars per STS for A. - 4, - /// The number of scalars per LDS for A. - 16, - /// The number of scalars per LDG for B. - 4, - /// The number of scalars per STS for B. - 4, - /// The number of scalars per LDS for B. - 16, - /// The number of scalars per LDG for C and STG for D. - 1, - /// The number of scalars per STS for D. - 4, - /// The number of scalars per LDS for D. - 1, - /// The number of stages in shared memory. - 2, - /// Enable the code path that deals with the residue in epilogue. - true> {}; + /// Tile size for thread-level GEMM (K-by-N-by-M) + typename ThreadGemmShape_> +struct IgemmConfig : public GemmConfig< + /// The scalar type for A. + int8_t, + /// The scalar type for B. + int8_t, + /// The scalar type for C. + ScalarD_, + /// The scalar type for D. + ScalarD_, + /// The tile size for the GEMM KxNxM. + OutputTile_, + /// The functor to do the math in the main loop. + ThreadMultiplyAdd, int8_t, int8_t, int>, + /// The number of scalars per LDG for A. + 4, + /// The number of scalars per STS for A. + 4, + /// The number of scalars per LDS for A. + 16, + /// The number of scalars per LDG for B. + 4, + /// The number of scalars per STS for B. + 4, + /// The number of scalars per LDS for B. + 16, + /// The number of scalars per LDG for C and STG for D. + 1, + /// The number of scalars per STS for D. + 4, + /// The number of scalars per LDS for D. + 1, + /// The number of stages in shared memory. + 2, + /// kResidueSeparate + false, + /// kResidueInPrologue + false, + /// kLaunchBounds + false> {}; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct IgemmConfig +template +struct IgemmConfig : public GemmConfig< /// The scalar type for A. int8_t, @@ -107,7 +110,7 @@ struct IgemmConfig /// The tile size for the GEMM KxNxM. 
OutputTile_, /// The functor to do the math in the main loop. - ThreadMultiplyAdd, int8_t, int8_t, int>, + ThreadMultiplyAdd, int8_t, int8_t, int>, /// The number of scalars per LDG for A. 4, /// The number of scalars per STS for A. @@ -128,8 +131,12 @@ struct IgemmConfig 4, /// The number of stages in shared memory. 2, - /// Enable the code path that deals with the residue in epilogue. - true> {}; + /// If true, separate mainloop is instantiated from residue + false, + /// Compute residue in prolog? + true, + /// Launch bounds? + false> {}; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -162,7 +169,7 @@ struct IgemmTileTraitsHelperA GemmConfig_::kScalarsPerLdgA> GlobalTileTraits; - // The iterator. + /// The global load iterator. typedef GemmGlobalIteratorAb GlobalLoadIterator; /// The traits class to build the iterator to store data to shared memory for A^N. @@ -208,7 +215,7 @@ struct IgemmTileTraitsHelperA { GemmConfig_::kScalarsPerLdgA> GlobalTileTraits; - // The iterator. + /// The global load iterator. typedef IgemmGlobalIteratorAb GlobalLoadIterator; /// The traits class to build the iterator to store data to shared memory for A^N. @@ -281,7 +288,7 @@ struct IgemmTileTraitsHelperB { GemmConfig_::kScalarsPerLdgB> GlobalTileTraits; - // The iterator. + /// The global load iterator. typedef IgemmGlobalIteratorAb GlobalLoadIterator; /// The traits class to build the iterator to store data to shared memory for B^N. @@ -345,7 +352,7 @@ struct IgemmTileTraitsHelperB GemmConfig_::kScalarsPerLdgB> GlobalTileTraits; - // The iterator. + /// The global load iterator. typedef GemmGlobalIteratorAb GlobalLoadIterator; /// The traits class to build the iterator to store data to shared memory for B^N. @@ -404,13 +411,13 @@ template < typename ScalarD_, /// The functor to do the math in the epilogue. typename EpilogueFunctor_, - /// The number of accumulators per thread. 
- typename AccumulatorsPerThread_ = Shape<32, 8, 8>, + /// Tile size for thread-level GEMM (K-by-N-by-M) + typename ThreadGemmShape_ = Shape<32, 8, 8>, /// The index. typename Index_ = int> struct IgemmTraitsHelper { /// The IGEMM config. - typedef IgemmConfig GemmConfig; + typedef IgemmConfig GemmConfig; /// The GEMM config for A. typedef IgemmTileTraitsHelperA GemmTileTraitsHelperA; /// The GEMM config for B. @@ -418,7 +425,6 @@ struct IgemmTraitsHelper { /// The iterator to load A from global memory. typedef typename GemmTileTraitsHelperA::GlobalLoadIterator GlobalLoadIteratorA; - /// The default transformer for A. typedef typename IgemmTransformerA::Transformer GlobalTransformerA; @@ -429,12 +435,14 @@ struct IgemmTraitsHelper { MemorySpace::kShared> SharedStoreIteratorA; /// The stream to load A from global memory to shared memory. - typedef GlobalLoadStream + typedef GlobalLoadStream GlobalLoadStreamA; /// The iterator to load B from global memory. typedef typename GemmTileTraitsHelperB::GlobalLoadIterator GlobalLoadIteratorB; - // The default transformer for B. typedef typename IgemmTransformerB::Transformer GlobalTransformerB; @@ -445,7 +453,10 @@ struct IgemmTraitsHelper { MemorySpace::kShared> SharedStoreIteratorB; /// The stream to load B from global memory to shared memory. - typedef GlobalLoadStream + typedef GlobalLoadStream GlobalLoadStreamB; /// The iterator to load A from shared memory. @@ -501,8 +512,8 @@ template < typename ScalarD_ = int, /// The functor to do the math in the epilogue. typename EpilogueFunctor_ = LinearScaling::Scalar>, - /// The number of accumulators per thread. - typename AccumulatorsPerThread_ = Shape<32, 8, 8>, + /// Tile size for thread-level GEMM (K-by-N-by-M) + typename ThreadGemmShape_ = Shape<32, 8, 8>, /// The index. typename Index_ = int, /// The helper class. 
@@ -511,7 +522,7 @@ template < OutputTile_, ScalarD_, EpilogueFunctor_, - AccumulatorsPerThread_, + ThreadGemmShape_, Index_> > struct IgemmTraits : public GemmTraits< // The config. diff --git a/cutlass/gemm/linear_scaling.h b/cutlass/gemm/linear_scaling.h index 979c93f96..a12fc5f19 100644 --- a/cutlass/gemm/linear_scaling.h +++ b/cutlass/gemm/linear_scaling.h @@ -1,3 +1,4 @@ + /*************************************************************************************************** * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. * @@ -27,18 +28,31 @@ */ #pragma once -#include +#include "cutlass/fragment_multiply_add.h" namespace cutlass { namespace gemm { //////////////////////////////////////////////////////////////////////////////////////////////////// +template +CUTLASS_DEVICE bool is_zero(T x) { + return x == T(0); +} + +#if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16) +CUTLASS_DEVICE bool is_zero(half x) { return reinterpret_cast(x) == int16_t(0); } +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + /// Functor to compute linear combination of fragments -template > +template > struct LinearScaling { // The scalar. typedef Scalar_ Scalar; + // The accumulator Type + typedef typename FragmentMultiplyAdd_::ScalarAccum ScalarAccum; // The adapater. typedef FragmentMultiplyAdd_ FragmentMultiplyAdd; @@ -47,6 +61,21 @@ struct LinearScaling { /// The alpha/beta scaling params. Scalar alpha, beta; + // + // Methods + // + + // Constructor + CUTLASS_HOST_DEVICE + Params(Scalar _alpha = 0, Scalar _beta = 0) : alpha(_alpha), beta(_beta) {} + + /// Initialize the parameters + CUTLASS_HOST_DEVICE int initialize(Scalar _alpha, Scalar _beta) { + alpha = _alpha; + beta = _beta; + return 0; + } + /// Initialize the parameters. 
template CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) { @@ -56,14 +85,53 @@ struct LinearScaling { } }; + // + // Data members + // + + Params params; + + // + // Methods + // + /// Ctor. - CUTLASS_DEVICE LinearScaling(Params const& params) : alpha(params.alpha), beta(params.beta) {} + CUTLASS_DEVICE LinearScaling() { } + + /// Ctor. + CUTLASS_DEVICE LinearScaling(Params const& _params) : params(_params) {} + + /// Method to determine whether the source accumulator matrix C is ever needed. This method + /// may always safely return true, though better performance is possible if the source accumulator + /// matrix is never loaded unnecessarily. + CUTLASS_DEVICE + bool source_required() const { + return !is_zero(params.beta); + } /// Evaluate the functor. template CUTLASS_DEVICE void evaluate(FragmentA_ const& accum, FragmentB_& output) { FragmentMultiplyAdd mad; - mad.multiply(alpha, accum, output); + mad.multiply(params.alpha, accum, output); + + } + + /// Evaluate the functor, without using fragment in the API + template + CUTLASS_DEVICE void evaluate(ScalarAccum const *accum, ScalarOutput *output) { + Fragment FragAccum; + Fragment FragOutput; +#pragma unroll + for (int i = 0; i < size; i++) { + FragAccum[i] = accum[i]; + FragOutput[i] = output[i]; + } + evaluate(FragAccum, FragOutput); +#pragma unroll + for (int i = 0; i < size; i++) { + output[i] = FragOutput[i]; + } } /// Evaluate the functor. @@ -71,12 +139,28 @@ struct LinearScaling { CUTLASS_DEVICE void evaluate(FragmentA_ const& accum, FragmentB_ const& old, FragmentB_& output) { FragmentMultiplyAdd mad; FragmentB_ tmp; - mad.multiply(beta, old, tmp); - mad.multiply_add(alpha, accum, tmp, output); + mad.multiply(params.beta, old, tmp); + mad.multiply_add(params.alpha, accum, tmp, output); } - /// The alpha/beta scaling factors. 
- Scalar alpha, beta; + /// Evaluate the functor, without using fragment in the API + template + CUTLASS_DEVICE void evaluate(ScalarAccum const *accum, ScalarOutput const *old, ScalarOutput *output) { + Fragment FragAccum; + Fragment FragOutput; + Fragment FragOld; +#pragma unroll + for (int i = 0; i < size; i++) { + FragAccum[i] = accum[i]; + FragOutput[i] = output[i]; + FragOld[i] = old[i]; + } + evaluate(FragAccum, FragOld, FragOutput); +#pragma unroll + for (int i = 0; i < size; i++) { + output[i] = FragOutput[i]; + } + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cutlass/gemm/linear_scaling_device_ptr.h b/cutlass/gemm/linear_scaling_device_ptr.h new file mode 100644 index 000000000..5dc845da4 --- /dev/null +++ b/cutlass/gemm/linear_scaling_device_ptr.h @@ -0,0 +1,149 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implements the BLAS linear scaling function alpha*AB + beta*C +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/scalar_or_pointer.h" +#include "cutlass/gemm/linear_scaling.h" + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace gemm { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Functor to compute linear combination of fragments. This is intended to support passing scalars +/// either by value from the host or by reference to device-side scalar elements. This is inspired +/// by cuBLAS's device pointer mode. +template > +struct LinearScalingDevicePtr : public LinearScaling { + + /// Linear Scaling class used + typedef LinearScaling Base; + + // The scalar. + typedef typename Base::Scalar Scalar; + + /// The parameters. 
+  class Params {
+  private:
+    /// Alpha scalar
+    detail::ScalarOrPointer<Scalar> alpha_;
+
+    /// Beta scalar
+    detail::ScalarOrPointer<Scalar> beta_;
+
+  public:
+    //
+    // Methods
+    //
+
+    // Constructor
+    CUTLASS_HOST_DEVICE
+    Params() {}
+
+    // Constructor
+    CUTLASS_HOST_DEVICE
+    Params(
+      Scalar alpha,
+      Scalar beta
+    ):
+      alpha_(alpha),
+      beta_(beta) {}
+
+    // Constructor taking device-side pointers (cuBLAS-style device pointer mode)
+    CUTLASS_HOST_DEVICE
+    Params(
+      Scalar const *alpha_ptr,
+      Scalar const *beta_ptr
+    ):
+      alpha_(alpha_ptr),
+      beta_(beta_ptr) {}  // fixed: was beta_(alpha_ptr), which ignored beta_ptr
+
+    /// Initialize the parameters from scalar values
+    CUTLASS_HOST_DEVICE int initialize(
+      Scalar alpha,
+      Scalar beta) {
+
+      alpha_ = alpha;
+      beta_ = beta;
+
+      return 0;
+    }
+
+    /// Initialize the parameters from pointers to scalars
+    CUTLASS_HOST_DEVICE int initialize(
+      Scalar const *alpha,
+      Scalar const *beta) {
+
+      alpha_ = alpha;
+      beta_ = beta;
+
+      return 0;
+    }
+
+    /// Initialize the parameters.
+    template <typename GemmDesc_>
+    CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
+
+      alpha_ = desc.alpha;
+      beta_ = desc.beta;
+
+      return 0;
+    }
+
+    /// Gets the alpha scalar
+    CUTLASS_HOST_DEVICE
+    Scalar alpha() const {
+      return alpha_;
+    }
+
+    /// Gets the beta scalar
+    CUTLASS_HOST_DEVICE
+    Scalar beta() const {
+      return beta_;
+    }
+  };
+
+  //
+  // Methods
+  //
+
+  /// Ctor.
+  CUTLASS_HOST_DEVICE LinearScalingDevicePtr(Params const& _params) {
+    this->params.alpha = _params.alpha();
+    this->params.beta = _params.beta();
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace gemm
+} // namespace cutlass
diff --git a/cutlass/gemm/scalar_or_pointer.h b/cutlass/gemm/scalar_or_pointer.h
new file mode 100644
index 000000000..7c4b4b75d
--- /dev/null
+++ b/cutlass/gemm/scalar_or_pointer.h
@@ -0,0 +1,129 @@
+
+/***************************************************************************************************
+ * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
+ * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implements the BLAS linear scaling function alpha*AB + beta*C +*/ +#pragma once + +#include "cutlass/cutlass.h" + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// Helper class defines an object which operates as either a scalar or a pointer. 
If the pointer +/// is non-null, it is dereferenced when the object is accessed. +template +class ScalarOrPointer { +public: + /// Underlying scalar type + typedef Scalar_ Scalar; + +private: + // + // Data members + // + + /// Scalar value + Scalar scalar; + + /// Pointer to use if non null + Scalar const *ptr; + +public: + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + ScalarOrPointer(): scalar(0), ptr(nullptr) {} + + /// Object behaves as a scalar + CUTLASS_HOST_DEVICE + ScalarOrPointer(Scalar const &val): scalar(val), ptr(nullptr) {} + + /// Object behaves as a scalar + CUTLASS_HOST_DEVICE + ScalarOrPointer(Scalar const *ptr_): scalar(0), ptr(ptr_) {} + + /// Returns true if is pointer + CUTLASS_HOST_DEVICE + bool is_pointer() const { + return bool(ptr); + } + + /// Gets the pointer value + CUTLASS_HOST_DEVICE + Scalar const *get_ptr() const { + return ptr; + } + + /// Gets the pointer value + CUTLASS_HOST_DEVICE + Scalar get_scalar() const { + return scalar; + } + + /// Assigns to a scalar and sets pointer to nullptr + CUTLASS_HOST_DEVICE + ScalarOrPointer &operator=(Scalar const &scalar_) { + scalar = scalar_; + ptr = nullptr; + return *this; + } + + /// Assigns to a pointer value + CUTLASS_HOST_DEVICE + ScalarOrPointer &operator=(Scalar const *ptr_) { + ptr = ptr_; + return *this; + } + + /// Access the element + CUTLASS_HOST_DEVICE + Scalar get() const { + if (ptr) { + return *ptr; + } + return scalar; + } + + /// Accesses the element + CUTLASS_HOST_DEVICE + operator Scalar() const { + return get(); + } +}; + +} // namespace detail + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/cutlass/gemm/sgemm_traits.h b/cutlass/gemm/sgemm_traits.h index 66b767748..8ce7f58e2 100644 --- a/cutlass/gemm/sgemm_traits.h +++ b/cutlass/gemm/sgemm_traits.h @@ -27,13 +27,13 @@ */ #pragma once -#include -#include -#include -#include -#include -#include -#include 
+#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/gemm_epilogue.h" +#include "cutlass/gemm/gemm_epilogue_traits.h" +#include "cutlass/gemm/gemm_global_tile.h" +#include "cutlass/gemm/gemm_shared_tile.h" +#include "cutlass/gemm/gemm_traits.h" +#include "cutlass/gemm/thread_multiply_add.h" namespace cutlass { namespace gemm { @@ -43,46 +43,53 @@ namespace gemm { template < /// The tile size for the GEMM KxNxM. typename OutputTile_, - /// The number of accumulators per thread. - typename AccumulatorsPerThread_, + /// Tile size for thread-level GEMM (K-by-N-by-M) + typename ThreadGemmShape_, /// The number of scalars per LDG for A. int kScalarsPerLdgA_ = 1, /// The number of scalars per LDG for B. - int kScalarsPerLdgB_ = 1> -struct SgemmConfig - : public GemmConfig< - /// The scalar type for A. - float, - /// The scalar type for B. - float, - /// The scalar type for C. - float, - /// The scalar type for D. - float, - /// The tile size for the GEMM KxNxM. - OutputTile_, - /// The functor to do the math in the main loop. - ThreadMultiplyAdd, float, float, float>, - /// The number of scalars per LDG for A. - kScalarsPerLdgA_, - /// The number of scalars per STS for A. - kScalarsPerLdgA_, - /// The number of scalars per LDS for A. - 4, - /// The number of scalars per LDG for B. - kScalarsPerLdgB_, - /// The number of scalars per STS for B. - kScalarsPerLdgB_, - /// The number of scalars per LDS for B. - 4, - /// The number of scalars per LDG for C and STG for D. - 1, - /// The number of scalars per STS for D. - 4, - /// The number of scalars per LDS for D. - 1, - /// The number of stages in shared memory. - 2> {}; + int kScalarsPerLdgB_ = 1, + /// Whether to specify launch bounds + bool kLaunchBounds = true> +struct SgemmConfig : public GemmConfig< + /// The scalar type for A. + float, + /// The scalar type for B. + float, + /// The scalar type for C. + float, + /// The scalar type for D. + float, + /// The tile size for the GEMM KxNxM. 
+ OutputTile_, + /// The functor to do the math in the main loop. + ThreadMultiplyAdd, float, float, float>, + /// The number of scalars per LDG for A. + kScalarsPerLdgA_, + /// The number of scalars per STS for A. + kScalarsPerLdgA_, + /// The number of scalars per LDS for A. + 4, + /// The number of scalars per LDG for B. + kScalarsPerLdgB_, + /// The number of scalars per STS for B. + kScalarsPerLdgB_, + /// The number of scalars per LDS for B. + 4, + /// The number of scalars per LDG for C and STG for D. + 1, + /// The number of scalars per STS for D. + 4, + /// The number of scalars per LDS for D. + 1, + /// The number of stages in shared memory. + 2, + /// kResidueSeparate + false, + /// kResidueInPrologue + true, + /// kLaunchBounds + kLaunchBounds> {}; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -95,8 +102,8 @@ template < typename OutputTile_ = Shape<8, 128, 128>, /// The functor to use in the epilogue. typename EpilogueFunctor_ = LinearScaling, - /// The number of accumulators per thread. - typename AccumulatorsPerThread_ = Shape<8, 8, 8>, + /// Tile size for thread-level GEMM (K-by-N-by-M) + typename ThreadGemmShape_ = Shape<8, 8, 8>, /// The number of floats loaded in one LDG for A. int kScalarsPerLdgA_ = 1, /// The number of floats loaded in one LDG for B. @@ -105,7 +112,7 @@ template < typename Index_ = int, /// The SGEMM config. typename GemmConfig_ = - SgemmConfig, + SgemmConfig, /// The traits class for the epilogue. typename GemmEpilogueTraits_ = SimplifiedGemmEpilogueTraits > @@ -123,5 +130,43 @@ struct SgemmTraits : public SimplifiedGemmTraits< //////////////////////////////////////////////////////////////////////////////////////////////////// +/// Helper to define SGEMM traits using Launch Bounds +template < + /// The layout for A. + MatrixLayout::Kind kLayoutA_, + /// The layout for B. + MatrixLayout::Kind kLayoutB_, + /// The output tile. 
+ typename OutputTile_ = Shape<8, 128, 128>, + /// The functor to use in the epilogue. + typename EpilogueFunctor_ = LinearScaling, + /// Tile size for thread-level GEMM (K-by-N-by-M) + typename ThreadGemmShape_ = Shape<8, 8, 8>, + /// The number of floats loaded in one LDG for A. + int kScalarsPerLdgA_ = 1, + /// The number of floats loaded in one LDG for B. + int kScalarsPerLdgB_ = 1, + /// The index. + typename Index_ = int, + /// The SGEMM config. + typename GemmConfig_ = + SgemmConfig, + /// The traits class for the epilogue. + typename GemmEpilogueTraits_ = + SimplifiedGemmEpilogueTraits > +struct SgemmLBTraits : public SimplifiedGemmTraits< + // The layout for A. + kLayoutA_, + // The layout for B. + kLayoutB_, + // The config. + GemmConfig_, + // The epilogue. + GemmEpilogue, + // The index. + Index_> {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace gemm } // namespace cutlass diff --git a/cutlass/gemm/thread_multiply_add.h b/cutlass/gemm/thread_multiply_add.h index 20dca1596..b95dee58a 100644 --- a/cutlass/gemm/thread_multiply_add.h +++ b/cutlass/gemm/thread_multiply_add.h @@ -27,7 +27,7 @@ */ #pragma once -#include +#include "cutlass/fragment.h" namespace cutlass { namespace gemm { @@ -35,20 +35,23 @@ namespace gemm { //////////////////////////////////////////////////////////////////////////////////////////////////// /// Template performing matrix multiply-add operation within a thread -template + typename ScalarC_, + MatrixLayout::Kind kLayout_ = MatrixLayout::kColumnMajor> struct ThreadMultiplyAdd { /// The shape of the instruction. typedef Shape<1, 1, 1, 1> InstructionShape; - /// The number of accumulators per thread. - typedef AccumulatorsPerThread_ AccumulatorsPerThread; + /// The shape of a thread-leveel matrix multiply accumulate. + typedef ThreadGemmShape_ ThreadGemmShape; + /// Aliased to "AccumulatorsPerThread" for compatibility. 
Expect to be renamed in CUTLASS v2.0 + typedef ThreadGemmShape AccumulatorsPerThread; /// The number of threads per warp. typedef ThreadsPerWarp_ ThreadsPerWarp; /// The number of accumulators per warp. - typedef typename ShapeMul::Shape AccumulatorsPerWarp; + typedef typename ShapeMul::Shape AccumulatorsPerWarp; /// The type for A. typedef ScalarA_ ScalarA; /// The fragment for A. @@ -70,9 +73,18 @@ struct ThreadMultiplyAdd { FragmentB const& b, Accumulators const& c, Accumulators& d) { - for (int j = 0; j < AccumulatorsPerThread::kH; ++j) { - for (int i = 0; i < AccumulatorsPerThread::kW; ++i) { - d[j * AccumulatorsPerThread::kW + i] = a[i] * b[j] + c[j * AccumulatorsPerThread::kW + i]; + if(kLayout_ == MatrixLayout::kColumnMajor) { + for (int j = 0; j < AccumulatorsPerThread::kH; ++j) { + for (int i = 0; i < AccumulatorsPerThread::kW; ++i) { + d[j * AccumulatorsPerThread::kW + i] = a[i] * b[j] + c[j * AccumulatorsPerThread::kW + i]; + } + } + } + else { + for(int i = 0; i < AccumulatorsPerThread::kW; ++i) { + for(int j = 0; j < AccumulatorsPerThread::kH; ++j) { + d[i * AccumulatorsPerThread::kH + j] = a[i] * b[j] + c[i * AccumulatorsPerThread::kH + j]; + } } } } diff --git a/cutlass/gemm/threadblock_swizzle.h b/cutlass/gemm/threadblock_swizzle.h new file mode 100644 index 000000000..fe7a3be7f --- /dev/null +++ b/cutlass/gemm/threadblock_swizzle.h @@ -0,0 +1,387 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. 
+ *     * Redistributions in binary form must reproduce the above copyright notice, this list of
+ *       conditions and the following disclaimer in the documentation and/or other materials
+ *       provided with the distribution.
+ *     * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
+ *       to endorse or promote products derived from this software without specific prior written
+ *       permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
+ * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
+ * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
+ * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
+ * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
+ * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+    \brief Defines functors for mapping blockIdx to partitions of the GEMM computation.
+*/
+#pragma once
+
+#include "cutlass/coord.h"
+#include "cutlass/gemm/gemm_coord.h"
+
+namespace cutlass {
+namespace gemm {
+
+struct swizzleDirection {
+  enum Kind { Boustrophedon, OneDirection };
+};
+// helper template function
+template <enum swizzleDirection::Kind>
+CUTLASS_DEVICE int getLinearIdx(int groups) {
+  // groupCols is not needed for OneDirection Swizzle
+  return blockIdx.y * gridDim.x + blockIdx.x;
+}
+template <>
+CUTLASS_DEVICE int getLinearIdx<swizzleDirection::Boustrophedon>(int groups) {
+  // reverse blockIdx.x for some columns
+  if ((blockIdx.y / groups) % 2 == 1)
+    return blockIdx.y * gridDim.x + (gridDim.x - blockIdx.x - 1);
+  else
+    return blockIdx.y * gridDim.x + blockIdx.x;
+}
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/*!@defgroup IdentityBlockSwizzle Identity Block Swizzle
+@{
+  Block Swizzle provides the mapping logic between a block in the physical memory of Matrix C and
+Thread Block
+  Identity Block Swizzle effectively maps blocks in leading dimension order (column major) with
+thread block
+  in leading dimension order (blockIdx.x)
+  blockIdx.z is mapped with batch_count for batched GEMM
+@}
+*/
+struct IdentityBlockSwizzle {
+  /// Ctor. aka ColumnMajorBlockSwizzle<1>
+  CUTLASS_HOST_DEVICE IdentityBlockSwizzle() {}
+
+  /// Swizzle the block index.
+ CUTLASS_DEVICE dim3 swizzle() { return blockIdx; } + + /// + CUTLASS_HOST_DEVICE dim3 get_grid_layout(GemmCoord const &problem_size, + Coord<3> const &OutputTile) { + /*OutputTile and problem_size are both in KNM order*/ + dim3 grid; + grid.x = (problem_size.m() + OutputTile[2] - 1) / OutputTile[2]; + grid.y = (problem_size.n() + OutputTile[1] - 1) / OutputTile[1]; + grid.z = problem_size.batch(); + return grid; + } + + /// + CUTLASS_DEVICE Coord<3> get_threadblock_offset(Coord<3> const &OutputTile) { + dim3 block = swizzle(); + Coord<3> threadblock_offset = + make_Coord(0, block.y * OutputTile[1], block.x * OutputTile[2]); + return threadblock_offset; + } + + /// + CUTLASS_DEVICE int get_batch_id() { + dim3 block = swizzle(); + return block.z; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/* +ColumnMajorBlockSwizzle<1, OneDirection> is equivalent with IdentityBlockSwizzle +groupCols has the effect of controlling the schedulling of thread blocks +settings with different groupCols can contribute to the overall performance by affecting L2 cache +hit rate + +consider a regular thread block mapping btween matrix C and different thread blocks +note that C is column major, and the leading dimension of thread block id is blockIdx.x + +let's look at an example where gridIdx.x = 6, gridIdx.y = 7, gridIdx.z = 1 +(blockIdx.x, blockIdx.y) +mapping between threadblockID and C matrix: +------------------------------------------------------- +(0,0) | (0,1) | (0,2) | (0,3) | (0,4) | (0,5) | (0,6) | +------------------------------------------------------- +(1,0) | (1,1) | (1,2) | (1,3) | (1,4) | (1,5) | (1,6) | +------------------------------------------------------- +(2,0) | (2,1) | (2,2) | (2,3) | (2,4) | (2,5) | (2,6) | +------------------------------------------------------- +(3,0) | (3,1) | (3,2) | (3,3) | (3,4) | (3,5) | (3,6) | +------------------------------------------------------- +(4,0) | (4,1) | 
(4,2) | (4,3) | (4,4) | (4,5) | (4,6) | +------------------------------------------------------- +(5,0) | (5,1) | (5,2) | (5,3) | (5,4) | (5,5) | (5,6) | +------------------------------------------------------- + +A ColumnMajorBlockSwizzle<1, OneDirection> will imply the above order where threadblocks are +launched in a column major + +A ColumnMajorBlockSwizzle<2, OneDirection> swizzles things a little, +------------------------------------------------------- +(0,0) | (3,0) | (0,2) | (3,2) | (0,4) | (3,4) | (0,6) | +------------------------------------------------------- +(0,1) | (3,1) | (0,3) | (3,3) | (0,5) | (3,5) | (1,6) | +------------------------------------------------------- +(1,0) | (4,0) | (1,2) | (4,2) | (1,4) | (4,4) | (2,6) | +------------------------------------------------------- +(1,1) | (4,1) | (1,3) | (4,3) | (1,5) | (4,5) | (3,6) | +------------------------------------------------------- +(2,0) | (5,0) | (2,2) | (5,2) | (2,4) | (5,4) | (4,6) | +------------------------------------------------------- +(2,1) | (5,1) | (2,3) | (5,3) | (2,5) | (5,5) | (5,6) | +------------------------------------------------------- + +so in memory, it would apprear that we work on 2 columns at a time rather than 1 +Note that the index here really represent how each block maps to memory + +A ColumnMajorBlockSwizzle<1, Boustrophedon> is similar to ColumnMajorBlockSwizzle<1, OneDirection> +except that every column flips the ordering against the previous one +------------------------------------------------------- +(0,0) | (5,1) | (0,2) | (5,3) | (0,4) | (5,5) | (0,6) | +------------------------------------------------------- +(1,0) | (4,1) | (1,2) | (4,3) | (1,4) | (4,5) | (1,6) | +------------------------------------------------------- +(2,0) | (3,1) | (2,2) | (3,3) | (2,4) | (3,5) | (2,6) | +------------------------------------------------------- +(3,0) | (2,1) | (3,2) | (2,3) | (3,4) | (2,5) | (3,6) | +------------------------------------------------------- +(4,0) | 
(1,1) | (4,2) | (1,3) | (4,4) | (1,5) | (4,6) | +------------------------------------------------------- +(5,0) | (0,1) | (5,2) | (0,3) | (5,4) | (0,5) | (5,6) | +------------------------------------------------------- + +similarily, A ColumnMajorBlockSwizzle<2, Boustrophedon> looks like +------------------------------------------------------- +(0,0) | (3,0) | (2,3) | (5,3) | (0,4) | (3,4) | (5,6) | +------------------------------------------------------- +(0,1) | (3,1) | (2,2) | (5,2) | (0,5) | (3,5) | (4,6) | +------------------------------------------------------- +(1,0) | (4,0) | (1,3) | (4,3) | (1,4) | (4,4) | (3,6) | +------------------------------------------------------- +(1,1) | (4,1) | (1,2) | (4,2) | (1,5) | (4,5) | (2,6) | +------------------------------------------------------- +(2,0) | (5,0) | (0,3) | (3,3) | (2,4) | (5,4) | (1,6) | +------------------------------------------------------- +(2,1) | (5,1) | (0,2) | (3,2) | (2,5) | (5,5) | (0,6) | +------------------------------------------------------- + +*/ + +template +struct ColumnMajorBlockSwizzle { + /// Ctor. + CUTLASS_HOST_DEVICE ColumnMajorBlockSwizzle() {} + + /// Swizzle the block index. 
+  CUTLASS_DEVICE dim3 swizzle() {
+    assert(gridDim.z == 1);
+    int linearIdx = getLinearIdx(groupCols);
+    dim3 swizzledBlockIdx;
+    int currGroupCols = groupCols;
+    int prevGroupCols = groupCols;
+
+    if ((gridDim.y % groupCols != 0) && ((blockIdx.y + (gridDim.y % groupCols)) >= gridDim.y)) {
+      // last columns if gridDim.y is not divisible by groupCols
+      currGroupCols = gridDim.y % groupCols;
+    }
+
+    swizzledBlockIdx.x = (linearIdx / currGroupCols) % gridDim.x;
+    swizzledBlockIdx.y =
+        linearIdx % currGroupCols + prevGroupCols * (linearIdx / (prevGroupCols * gridDim.x));
+    swizzledBlockIdx.z = blockIdx.z;
+
+    return swizzledBlockIdx;
+  }
+
+  ///
+  CUTLASS_HOST_DEVICE dim3 get_grid_layout(GemmCoord const &problem_size,
+                                           Coord<3> const &OutputTile) {
+    dim3 grid;
+    grid.x = (problem_size.m() + OutputTile[2] - 1) / OutputTile[2];
+    grid.y = (problem_size.n() + OutputTile[1] - 1) / OutputTile[1];
+    grid.z = problem_size.batch();
+    return grid;
+  }
+
+  ///
+  CUTLASS_DEVICE Coord<3> get_threadblock_offset(Coord<3> const &OutputTile) {
+    dim3 block = swizzle();
+    Coord<3> threadblock_offset =
+        make_Coord(0, block.y * OutputTile[1], block.x * OutputTile[2]);
+    return threadblock_offset;
+  }
+
+  ///
+  CUTLASS_DEVICE int get_batch_id() {
+    dim3 block = swizzle();
+    return block.z;
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/*
+
+consider a regular thread block mapping between matrix C and different thread blocks
+note that C is column major, and the leading dimension of thread block id is blockIdx.x
+
+let's look at an example where gridIdx.x = 6, gridIdx.y = 7, gridIdx.z = 1
+(blockIdx.x, blockIdx.y)
+mapping between threadblockID and C matrix:
+-------------------------------------------------------
+(0,0) | (0,1) | (0,2) | (0,3) | (0,4) | (0,5) | (0,6) |
+-------------------------------------------------------
+(1,0) | (1,1) | (1,2) | (1,3) | (1,4) | (1,5) | (1,6) |
+------------------------------------------------------- +(2,0) | (2,1) | (2,2) | (2,3) | (2,4) | (2,5) | (2,6) | +------------------------------------------------------- +(3,0) | (3,1) | (3,2) | (3,3) | (3,4) | (3,5) | (3,6) | +------------------------------------------------------- +(4,0) | (4,1) | (4,2) | (4,3) | (4,4) | (4,5) | (4,6) | +------------------------------------------------------- +(5,0) | (5,1) | (5,2) | (5,3) | (5,4) | (5,5) | (5,6) | +------------------------------------------------------- + +A RowMajorBlockSwizzle<1, OneDirection> will effectively transpose the map + +----------------------------------------------- +(0,0) | (1,0) | (2,0) | (3,0) | (4,0) | (5,0) | +----------------------------------------------- +(0,1) | (1,1) | (2,1) | (3,1) | (4,1) | (5,1) | +----------------------------------------------- +(0,2) | (1,2) | (2,2) | (3,2) | (4,2) | (5,2) | +----------------------------------------------- +(0,3) | (1,3) | (2,3) | (3,3) | (4,3) | (5,3) | +----------------------------------------------- +(0,4) | (1,4) | (2,4) | (3,4) | (4,4) | (5,4) | +--------------------------------------------- +(0,5) | (1,5) | (2,5) | (3,5) | (4,5) | (5,5) | +----------------------------------------------- +(0,6) | (1,6) | (2,6) | (3,6) | (4,6) | (5,6) | +----------------------------------------------- + +It would aprear in memory we are working on 1 row at a time + +A ColumnMajorBlockSwizzle<2, OneDirection> swizzles things a little bit more +----------------------------------------------- +(0,0) | (1,3) | (2,0) | (3,3) | (4,0) | (5,3) | +----------------------------------------------- +(1,0) | (0,4) | (3,0) | (2,4) | (5,0) | (4,4) | +----------------------------------------------- +(0,1) | (1,4) | (2,1) | (3,4) | (4,1) | (5,4) | +----------------------------------------------- +(1,1) | (0,5) | (3,1) | (2,5) | (5,1) | (4,5) | +----------------------------------------------- +(0,2) | (1,5) | (2,2) | (3,5) | (4,2) | (5,5) | 
+--------------------------------------------- +(1,2) | (0,6) | (3,2) | (2,6) | (5,2) | (4,6) | +----------------------------------------------- +(0,3) | (1,6) | (2,3) | (3,6) | (4,3) | (5,6) | +----------------------------------------------- + +so in memory, it would apprear that we work on 2 rows at a time rather than 1 row +Note that the index here really represent how each block maps to memory + +A RowMajorBlockSwizzle<1, Boustrophedon> is similar to RowMajorBlockSwizzle<1, OneDirection> +except that every column flips the ordering against the previous one + +----------------------------------------------- +(0,0) | (1,6) | (2,0) | (3,6) | (4,0) | (5,6) | +----------------------------------------------- +(0,1) | (1,5) | (2,1) | (3,5) | (4,1) | (5,5) | +----------------------------------------------- +(0,2) | (1,4) | (2,2) | (3,4) | (4,2) | (5,4) | +----------------------------------------------- +(0,3) | (1,3) | (2,3) | (3,3) | (4,3) | (5,3) | +----------------------------------------------- +(0,4) | (1,2) | (2,4) | (3,2) | (4,4) | (5,2) | +--------------------------------------------- +(0,5) | (1,1) | (2,5) | (3,1) | (4,5) | (5,1) | +----------------------------------------------- +(0,6) | (1,0) | (2,6) | (3,0) | (4,6) | (5,0) | +----------------------------------------------- + +similarily, A RowMajorBlockSwizzle<2, Boustrophedon> looks like +----------------------------------------------- +(0,0) | (1,3) | (2,3) | (3,6) | (4,0) | (5,3) | +----------------------------------------------- +(1,0) | (0,4) | (3,2) | (2,6) | (5,0) | (4,4) | +----------------------------------------------- +(0,1) | (1,4) | (2,2) | (3,5) | (4,1) | (5,4) | +----------------------------------------------- +(1,1) | (0,5) | (3,1) | (2,5) | (5,1) | (4,5) | +----------------------------------------------- +(0,2) | (1,5) | (2,1) | (3,4) | (4,2) | (5,5) | +--------------------------------------------- +(1,2) | (0,6) | (3,0) | (2,4) | (5,2) | (4,6) | 
+----------------------------------------------- +(0,3) | (1,6) | (2,0) | (3,3) | (4,3) | (5,6) | +----------------------------------------------- + +*/ + +template +struct RowMajorBlockSwizzle { + /// Ctor. + CUTLASS_HOST_DEVICE RowMajorBlockSwizzle() {} + + /// Swizzle the block index. + CUTLASS_DEVICE dim3 swizzle() { + assert(gridDim.z == 1); + int linearIdx = getLinearIdx(groupRows); + dim3 swizzledBlockIdx; + int currGroupRows = groupRows; + int prevGroupRows = groupRows; + + if ((gridDim.y % groupRows != 0) && ((blockIdx.y + (gridDim.y % groupRows)) >= gridDim.y)) { + // last columns + currGroupRows = gridDim.y % groupRows; + } + + swizzledBlockIdx.x = + linearIdx % currGroupRows + prevGroupRows * (linearIdx / (prevGroupRows * gridDim.x)); + swizzledBlockIdx.y = (linearIdx / currGroupRows) % gridDim.x; + swizzledBlockIdx.z = blockIdx.z; + + return swizzledBlockIdx; + } + + /// + CUTLASS_HOST_DEVICE dim3 get_grid_layout(GemmCoord const &problem_size, + Coord<3> const &OutputTile) { + dim3 grid; + grid.x = (problem_size.n() + OutputTile[1] - 1) / OutputTile[1]; + grid.y = (problem_size.m() + OutputTile[2] - 1) / OutputTile[2]; + grid.z = problem_size.batch(); + return grid; + } + + /// + CUTLASS_DEVICE Coord<3> get_threadblock_offset(Coord<3> const &OutputTile) { + dim3 block = swizzle(); + Coord<3> threadblock_offset = + make_Coord(0, block.y * OutputTile[1], block.x * OutputTile[2]); + return threadblock_offset; + } + + /// + CUTLASS_DEVICE int get_batch_id() { + dim3 block = swizzle(); + return block.z; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace gemm +} // namespace cutlass diff --git a/cutlass/gemm/wmma_gemm_epilogue_traits.h b/cutlass/gemm/wmma_gemm_epilogue_traits.h index 0fafacf90..f35264dda 100644 --- a/cutlass/gemm/wmma_gemm_epilogue_traits.h +++ b/cutlass/gemm/wmma_gemm_epilogue_traits.h @@ -27,18 +27,18 @@ */ #pragma once -#include +#include 
"cutlass/wmma_matrix.h" #ifdef CUTLASS_USE_WMMA_API -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include "cutlass/convert.h" +#include "cutlass/coord.h" +#include "cutlass/gemm/gemm_global_stream.h" +#include "cutlass/gemm/gemm_shared_stream.h" +#include "cutlass/gemm/linear_scaling.h" +#include "cutlass/gemm/wmma_gemm_global_tile.h" +#include "cutlass/gemm/wmma_gemm_shared_tile.h" +#include "cutlass/reshape_tile.h" +#include "cutlass/tile_iterator.h" namespace cutlass { namespace gemm { @@ -89,7 +89,7 @@ struct WmmaGemmEpilogueTraitsHelper { MemorySpace::kShared, Index_, WmmaMatrix, - IteratorFragment::kWmmaMatrix> + FragmentElementType::kWmmaMatrix> SharedStoreIteratorD; /// The shared store transformer for D. @@ -114,6 +114,9 @@ struct WmmaGemmEpilogueTraitsHelper { MemorySpace::kShared> SharedLoadIteratorD; + /// The stream to load D. + typedef SharedLoadStream SharedLoadStreamD; + /// The traits class to build the iterator to load data from global memory for C^N. typedef WmmaGemmGlobalIteratorCdTraits< // The pointer is float const. diff --git a/cutlass/gemm/wmma_gemm_global_tile.h b/cutlass/gemm/wmma_gemm_global_tile.h index dbd57f6b5..ce369d0eb 100644 --- a/cutlass/gemm/wmma_gemm_global_tile.h +++ b/cutlass/gemm/wmma_gemm_global_tile.h @@ -27,7 +27,7 @@ */ #pragma once -#include +#include "cutlass/gemm/gemm_global_tile.h" namespace cutlass { namespace gemm { @@ -68,22 +68,13 @@ struct WmmaGemmGlobalIteratorCdTraits : public GemmGlobalTileTraits -struct WmmaGemmGlobalIteratorCd : public TileIteratorBase { +struct WmmaGemmGlobalIteratorCd : public GemmGlobalIteratorCd { /// This class. typedef WmmaGemmGlobalIteratorCd This_; /// The traits. typedef TileTraits_ Traits; /// The base class. - typedef TileIteratorBase - Base; + typedef GemmGlobalIteratorCd Base; /// Override the strides in each dimension between different loads/stores. 
typedef Shape<0, 0, Base::Delta::kW, Base::Delta::kC> ImmediateOffsetStrides; /// The layout. @@ -99,47 +90,36 @@ struct WmmaGemmGlobalIteratorCd : public TileIteratorBasepointer = pointer; + BaseParams::pointer = pointer; + // Stride between GEMMs + BaseParams::stride_d = batch_stride; // Setup the base stride. One "group of threads" per column. - stride_h = ld; + BaseParams::stride_h = ldm; // Each thread output 1 column per iteration. . - inc_h = ld * TileTraits_::Threads::kH; - inc_advance = inc_h + epilogue_stride_w; + BaseParams::inc_h = ldm * TileTraits_::Threads::kH; + BaseParams::inc_advance = BaseParams::inc_h + epilogue_stride_w; - predicate_offset = n; - predicate_inc_h = TileTraits_::Threads::kH; - predicate_inc_advance = predicate_inc_h + epilogue_delta_w; + BaseParams::predicate_offset = n; + BaseParams::predicate_inc_h = TileTraits_::Threads::kH; + BaseParams::predicate_inc_advance = BaseParams::predicate_inc_h + epilogue_delta_w; - // It worked. return 0; } }; - Params params; - - Coord<4> thread_offset; - - /// Ctor. - CUTLASS_DEVICE WmmaGemmGlobalIteratorCd() {} - /// Ctor. CUTLASS_DEVICE WmmaGemmGlobalIteratorCd(Params const& params, const Coord<3>& bounds, @@ -148,61 +128,37 @@ struct WmmaGemmGlobalIteratorCd : public TileIteratorBaseparams.pointer += ((h * params.stride_h + w) + pointer_offset); + : Base(params, bounds, block, pointer_offset, pred_offset, thread_offset_func) {} - // Prepare the vector of predicates. - for (int i = 0; i < Base::Iterations::kW; ++i) { - predicates.set(i, w + i * Base::Delta::kW < bounds[2]); - } - this->params.predicate_offset -= (h + pred_offset); + /// Loads a single fragment element from memory + CUTLASS_DEVICE void load_element( + typename Base::AccessType& value, int d, int h, int w, int c) const { + Base::load_element(value, d, h, w, c); } - /// The accessor. 
- CUTLASS_DEVICE void get(typename Base::AccessType& value, int d, int h, int w, int c) const { - int const imm = - ComputeOffsetFromStrides::get(0, 0, w, c); - Load::load(value, params.pointer, imm); - } - - /// Increment the pointer in the C dimension. - CUTLASS_DEVICE void inc_c() {} - /// Increment the pointer in the W dimension. - CUTLASS_DEVICE void inc_w() {} - /// Increment the pointer in the H dimension. - CUTLASS_DEVICE void inc_h() { - params.pointer += params.inc_h; - params.predicate_offset -= params.predicate_inc_h; - } - /// Increment the pointer in the D dimension. - CUTLASS_DEVICE void inc_d() {} - /// Increment the pointer to move to the next iteration. - CUTLASS_DEVICE void inc_advance() { - params.pointer += params.inc_advance; - params.predicate_offset -= params.predicate_inc_advance; - } - - /// The accessor. - CUTLASS_DEVICE void set(typename Base::AccessType const& value, int d, int h, int w, int c) { - int const imm = + /// Stores a single fragment element into memory + CUTLASS_DEVICE void store_element( + typename Base::AccessType const& value, int d, int h, int w, int c) { + int const offset = ComputeOffsetFromStrides::get(d, h, w, 0); - Store::store( - value, params.pointer, imm); + Store::store(value, Base::params.pointer, offset); } - /// Test the predicate. - CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const { - return predicates.at(w) && params.predicate_offset > 0; + public: + template + CUTLASS_DEVICE void load_post_increment(Fragment& fragment) { + Base::load_post_increment(fragment); } - /// The predicates for the row. 
- cutlass::PredicateVector predicates; + template + CUTLASS_DEVICE void store_post_increment(Fragment& fragment) { + Base::store_post_increment(fragment); + } }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cutlass/gemm/wmma_gemm_multiply_add.h b/cutlass/gemm/wmma_gemm_multiply_add.h index 5968350e0..328e43adb 100644 --- a/cutlass/gemm/wmma_gemm_multiply_add.h +++ b/cutlass/gemm/wmma_gemm_multiply_add.h @@ -27,9 +27,9 @@ */ #pragma once -#include +#include "cutlass/wmma_matrix.h" #ifdef CUTLASS_USE_WMMA_API -#include +#include "cutlass/fragment.h" namespace cutlass { namespace gemm { @@ -42,15 +42,17 @@ template struct WmmaGemmMultiplyAdd { /// The shape of the instruction. typedef InstructionShape_ InstructionShape; /// The number of threads per warp. That's a dummy configuration. typedef Shape<1, InstructionShape_::kH, InstructionShape_::kW> ThreadsPerWarp; - /// The dimensions. - typedef AccumulatorsPerWarp_ AccumulatorsPerWarp; + /// Dimensions of the warp-level GEMM (K-by-N-by-M) + typedef WarpGemmShape_ WarpGemmShape; + /// Aliased for compatibility. Will be removed in CUTLASS v2.0 + typedef WarpGemmShape_ AccumulatorsPerWarp; /// The type for A. typedef ScalarA_ ScalarA; /// The type for B. @@ -102,6 +104,251 @@ struct WmmaGemmMultiplyAdd { //////////////////////////////////////////////////////////////////////////////////////////////////// +#ifdef CUTLASS_USE_SUBBYTE_WMMA +/// Specialization for WMMA GEMM with binary operands +template +struct WmmaGemmMultiplyAdd , + MatrixLayout::kColumnMajor, + Vector, + MatrixLayout::kColumnMajor, + int, + WarpGemmShape_, + Shape<128, 8, 8> >{ + /// The shape of the instruction. + typedef Shape<128, 8, 8> InstructionShape; + /// The number of threads per warp. That's a dummy configuration. + typedef Shape<1, 4, 8> ThreadsPerWarp; + /// Dimensions of the warp-level GEMM (K-by-N-by-M) + typedef WarpGemmShape_ WarpGemmShape; + /// Aliased for compatibility. 
Will be removed in CUTLASS v2.0 + typedef WarpGemmShape_ AccumulatorsPerWarp; + /// The type for A. + typedef Vector ScalarA; + /// The type for B. + typedef Vector ScalarB; + /// The type for C and D. + typedef int ScalarC; + /// The number of iterations. + typedef typename ShapeDiv::Shape Iterations; + + /// The element for A. + typedef WmmaMatrix, + InstructionShape> ElementA; + /// The fragment for A. + typedef Fragment FragmentA; + + /// The element for B. + typedef WmmaMatrix, + InstructionShape> ElementB; + /// The fragment for B. + typedef Fragment FragmentB; + + /// The element for C. + typedef WmmaMatrix ElementC; + /// The fragment for C. + typedef Fragment Accumulators; + + /// Ctor. + CUTLASS_DEVICE WmmaGemmMultiplyAdd() {} + + /// Multiply : d = a*b. + CUTLASS_DEVICE void multiply_add(FragmentA const& a, + FragmentB const& b, + Accumulators const& c, + Accumulators& d) { + for (int j = 0; j < Iterations::kH; ++j) { + for (int i = 0; i < Iterations::kW; ++i) { + // The input elements. + ElementA const& elt_a = a[i]; + ElementB const& elt_b = b[j]; + ElementC const& elt_c = c[j * Iterations::kW + i]; + + // The output element. + ElementC& elt_d = d[j * Iterations::kW + i]; + + // The wmma instruction. + nvcuda::wmma::bmma_sync(elt_d, + elt_a, + elt_b, + elt_c, + nvcuda::wmma::experimental::bmmaBitOpXOR, + nvcuda::wmma::experimental::bmmaAccumulateOpPOPC); + } + } + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef CUTLASS_USE_SUBBYTE_WMMA +/// Specialization for WMMA GEMM with signed 4-bit integer operands +template +struct WmmaGemmMultiplyAdd , + MatrixLayout::kColumnMajor, + Vector, + MatrixLayout::kColumnMajor, + int, + WarpGemmShape_, + Shape<32, 8, 8> >{ + /// The shape of the instruction. + typedef Shape<32, 8, 8> InstructionShape; + /// The number of threads per warp. That's a dummy configuration. 
+ typedef Shape<1, 4, 8> ThreadsPerWarp; + /// Dimensions of the warp-level GEMM (K-by-N-by-M) + typedef WarpGemmShape_ WarpGemmShape; + /// Aliased for compatibility. Will be removed in CUTLASS v2.0 + typedef WarpGemmShape_ AccumulatorsPerWarp; + /// The type for A. + typedef Vector ScalarA; + /// The type for B. + typedef Vector ScalarB; + /// The type for C and D. + typedef int ScalarC; + /// The number of iterations. + typedef typename ShapeDiv::Shape Iterations; + + /// The element for A. + typedef WmmaMatrix, + InstructionShape> ElementA; + /// The fragment for A. + typedef Fragment FragmentA; + + /// The element for B. + typedef WmmaMatrix, + InstructionShape> ElementB; + /// The fragment for B. + typedef Fragment FragmentB; + + /// The element for C. + typedef WmmaMatrix ElementC; + /// The fragment for C. + typedef Fragment Accumulators; + + /// Ctor. + CUTLASS_DEVICE WmmaGemmMultiplyAdd() {} + + /// Multiply : d = a*b. + CUTLASS_DEVICE void multiply_add(FragmentA const& a, + FragmentB const& b, + Accumulators const& c, + Accumulators& d) { + for (int j = 0; j < Iterations::kH; ++j) { + for (int i = 0; i < Iterations::kW; ++i) { + // The input elements. + ElementA const& elt_a = a[i]; + ElementB const& elt_b = b[j]; + ElementC const& elt_c = c[j * Iterations::kW + i]; + + // The output element. + ElementC& elt_d = d[j * Iterations::kW + i]; + + // The wmma instruction. + nvcuda::wmma::mma_sync(elt_d, elt_a, elt_b, elt_c); + } + } + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef CUTLASS_USE_SUBBYTE_WMMA +/// Specialization for WMMA GEMM with unsigned 4-bit integer operands +template +struct WmmaGemmMultiplyAdd , + MatrixLayout::kColumnMajor, + Vector, + MatrixLayout::kColumnMajor, + int, + WarpGemmShape_, + Shape<32, 8, 8> >{ + /// The shape of the instruction. + typedef Shape<32, 8, 8> InstructionShape; + /// The number of threads per warp. That's a dummy configuration. 
+ typedef Shape<1, 4, 8> ThreadsPerWarp; + /// Dimensions of the warp-level GEMM (K-by-N-by-M) + typedef WarpGemmShape_ WarpGemmShape; + /// Aliased for compatibility. Will be removed in CUTLASS v2.0 + typedef WarpGemmShape_ AccumulatorsPerWarp; + /// The type for A. + typedef Vector ScalarA; + /// The type for B. + typedef Vector ScalarB; + /// The type for C and D. + typedef int ScalarC; + /// The number of iterations. + typedef typename ShapeDiv::Shape Iterations; + + /// The element for A. + typedef WmmaMatrix, + InstructionShape> ElementA; + /// The fragment for A. + typedef Fragment FragmentA; + + /// The element for B. + typedef WmmaMatrix, + InstructionShape> ElementB; + /// The fragment for B. + typedef Fragment FragmentB; + + /// The element for C. + typedef WmmaMatrix ElementC; + /// The fragment for C. + typedef Fragment Accumulators; + + /// Ctor. + CUTLASS_DEVICE WmmaGemmMultiplyAdd() {} + + /// Multiply : d = a*b. + CUTLASS_DEVICE void multiply_add(FragmentA const& a, + FragmentB const& b, + Accumulators const& c, + Accumulators& d) { + for (int j = 0; j < Iterations::kH; ++j) { + for (int i = 0; i < Iterations::kW; ++i) { + // The input elements. + ElementA const& elt_a = a[i]; + ElementB const& elt_b = b[j]; + ElementC const& elt_c = c[j * Iterations::kW + i]; + + // The output element. + ElementC& elt_d = d[j * Iterations::kW + i]; + + // The wmma instruction. 
+ nvcuda::wmma::mma_sync(elt_d, elt_a, elt_b, elt_c); + } + } + } +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace gemm } // namespace cutlass diff --git a/cutlass/gemm/wmma_gemm_shared_tile.h b/cutlass/gemm/wmma_gemm_shared_tile.h index 7d15b260f..1a90e2f10 100644 --- a/cutlass/gemm/wmma_gemm_shared_tile.h +++ b/cutlass/gemm/wmma_gemm_shared_tile.h @@ -28,18 +28,15 @@ */ #pragma once -#include +#include "cutlass/wmma_matrix.h" #ifdef CUTLASS_USE_WMMA_API -#include -#include +#include "cutlass/gemm/gemm_operand.h" +#include "cutlass/reshape_tile.h" namespace cutlass { namespace gemm { -template -struct Debug {}; - //////////////////////////////////////////////////////////////////////////////////////////////////// template +#include "cutlass/wmma_matrix.h" #ifdef CUTLASS_USE_WMMA_API -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#include "cutlass/convert.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/gemm_epilogue.h" +#include "cutlass/gemm/gemm_epilogue_traits.h" +#include "cutlass/gemm/gemm_global_tile.h" +#include "cutlass/gemm/gemm_shared_tile.h" +#include "cutlass/gemm/gemm_traits.h" +#include "cutlass/gemm/wmma_gemm_epilogue_traits.h" +#include "cutlass/gemm/wmma_gemm_global_tile.h" +#include "cutlass/gemm/wmma_gemm_multiply_add.h" namespace cutlass { namespace gemm { @@ -53,12 +53,16 @@ template < MatrixLayout::Kind kLayoutB_, /// The tile size for the GEMM KxNxM. typename OutputTile_, + /// The input type. + typename ScalarA_, + /// The input type. + typename ScalarB_, /// The output type. typename ScalarC_, /// The accumulator type. typename Accumulator_, - /// The number of accumulators per warp. - typename AccumulatorsPerWarp_, + /// Tile size for warp-level GEMM (K-by-N-by-M) + typename WarpGemmShape_, /// The shape of the WMMA instruction. 
typename InstructionShape_, /// The number of scalars per LDG for A. @@ -67,9 +71,9 @@ template < int kScalarsPerLdgB_> struct WmmaGemmConfig : public GemmConfig< /// The scalar type for A. - half, + ScalarA_, /// The scalar type for B. - half, + ScalarB_, /// The scalar type for C. ScalarC_, /// The scalar type for D. @@ -78,12 +82,12 @@ struct WmmaGemmConfig : public GemmConfig< OutputTile_, /// The functor to do the math in the main loop. WmmaGemmMultiplyAdd, /// The number of scalars per LDG for A. kScalarsPerLdgA_, @@ -100,21 +104,29 @@ struct WmmaGemmConfig : public GemmConfig< /// The number of scalars per LDG for C and STG for D. 16 / sizeof(ScalarC_), /// The number of scalars per STS for D. - 16 / sizeof(ScalarC_), + 16 / sizeof(Accumulator_), /// The number of scalars per LDS for D. - 16 / sizeof(ScalarC_), + 16 / sizeof(Accumulator_), /// The number of stages in shared memory. - 1> {}; + 1, + /// If true, residue is computed in mainloop. If false, separate loops are instantiated. + false, + /// Is residue performed in prologue? + true, + /// If true, kernel is launched with CUDA launch bounds specified + false> {}; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct WmmaGemmTileTraitsHelperA {}; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct WmmaGemmTileTraitsHelperA +template +struct WmmaGemmTileTraitsHelperA : public GemmTileTraitsHelperA { /// The base config. typedef GemmTileTraitsHelperA Base; @@ -173,8 +185,8 @@ struct WmmaGemmTileTraitsHelperA //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct WmmaGemmTileTraitsHelperA { +template +struct WmmaGemmTileTraitsHelperA { /// The layout. 
static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor; @@ -251,13 +263,276 @@ struct WmmaGemmTileTraitsHelperA { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +#ifdef CUTLASS_USE_SUBBYTE_WMMA +/// Specialization for WMMA GEMM with binary operands +template +struct WmmaGemmTileTraitsHelperA > { + /// The layout. + static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor; + + /// The input scalar. + typedef typename GemmConfig_::ScalarA Scalar; + /// The scalar stored in shared memory. + typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar; + + /// GemmConfig_::OutputTile::kD is in number of 'bits'. TileTraits expects number of 'Scalar'. + /// Divide by 'kBitsPerScalar' to get the number in 'Scalar'. + static int const kBitsPerScalar = sizeof(Scalar) * 8; + + /// WMMA matrix + typedef WmmaMatrix, + typename GemmConfig_::InstructionShape> + WmmaMatrix; + + /// The traits class to build the iterator to load data from global memory for A^T. + typedef GemmGlobalTileTraits< + // That's A. + GemmOperand::kA, + // A is row-major. + MatrixLayout::kRowMajor, + // The pointer is float const. + Scalar const, + // The tile has size KxM in GEMM's terminology. + Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD / kBitsPerScalar>, + // The threads are distributed as warps x 32 (the traits may reorganize). + Shape<1, + GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kBitsPerScalar), + GemmConfig_::OutputTile::kD / kBitsPerScalar>, + // The number of scalars per LDG (LDG.32 or LDG.128, etc). + GemmConfig_::kScalarsPerLdgA / kBitsPerScalar> + GlobalTileTraits; + + /// The skew. + static int const kSkew = 16 / sizeof(MultiplyAddScalar); + /// The tile. + typedef Shape + Tile; + + /// The traits class to build the iterator to store data to shared memory for A^N. + typedef GemmSharedStoreTileAbTraits< + // The pointer. 
+ MultiplyAddScalar, + // The tile has size KxM in GEMM's terminology. + Tile, + // The threads are distributed as warps x 32 (the traits may reorganize). + typename GlobalTileTraits::Threads, + // The number of scalars per STS (STS.32 or STS.128, etc). + GemmConfig_::kScalarsPerStsA / kBitsPerScalar> + SharedStoreTileTraits; + + /// The number of elements loaded in one LDG. + static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW; + /// The traits class to build the iterator to load from shared memory for A. + typedef WmmaGemmSharedLoadTileATraits< + // The layout of the matrix. + MatrixLayout::kRowMajor, + // The pointer. + MultiplyAddScalar, + // The tile in shared memory. + Tile, + // The number of warps. + typename GemmConfig_::Warps, + // The strides between warps. + GemmConfig_::InstructionShape::kW * Tile::kW, + // The number of iterations to load the data. + Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>, + // The stride between iterations. + Shape, + // The shape of the instruction. + typename GemmConfig_::InstructionShape> + SharedLoadTileTraits; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef CUTLASS_USE_SUBBYTE_WMMA +/// Specialization for WMMA GEMM with unsigned 4-bit integer operands +template +struct WmmaGemmTileTraitsHelperA > { + /// The layout. + static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor; + + /// The input scalar. + typedef typename GemmConfig_::ScalarA Scalar; + /// The scalar stored in shared memory. + typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar; + + /// GemmConfig_::OutputTile::kD is in number of 'int4'. TileTraits expects number of 'Scalar'. + /// Divide by 'kInt4PerScalar' to get the number in 'Scalar'. 
+ static int const kInt4PerScalar = sizeof(Scalar) * 2; + + /// WMMA matrix + typedef WmmaMatrix, + typename GemmConfig_::InstructionShape> + WmmaMatrix; + + /// The traits class to build the iterator to load data from global memory for A^T. + typedef GemmGlobalTileTraits< + // That's A. + GemmOperand::kA, + // A is row-major. + MatrixLayout::kRowMajor, + // The pointer is float const. + Scalar const, + // The tile has size KxM in GEMM's terminology. + Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD / kInt4PerScalar>, + // The threads are distributed as warps x 32 (the traits may reorganize). + Shape<1, + GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kInt4PerScalar), + GemmConfig_::OutputTile::kD / kInt4PerScalar>, + // The number of scalars per LDG (LDG.32 or LDG.128, etc). + GemmConfig_::kScalarsPerLdgA / kInt4PerScalar> + GlobalTileTraits; + + /// The skew. + static int const kSkew = 16 / sizeof(MultiplyAddScalar); + /// The tile. + typedef Shape + Tile; + + /// The traits class to build the iterator to store data to shared memory for A^N. + typedef GemmSharedStoreTileAbTraits< + // The pointer. + MultiplyAddScalar, + // The tile has size KxM in GEMM's terminology. + Tile, + // The threads are distributed as warps x 32 (the traits may reorganize). + typename GlobalTileTraits::Threads, + // The number of scalars per STS (STS.32 or STS.128, etc). + GemmConfig_::kScalarsPerStsA / kInt4PerScalar> + SharedStoreTileTraits; + + /// The number of elements loaded in one LDG. + static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW; + /// The traits class to build the iterator to load from shared memory for A. + typedef WmmaGemmSharedLoadTileATraits< + // The layout of the matrix. + MatrixLayout::kRowMajor, + // The pointer. + MultiplyAddScalar, + // The tile in shared memory. + Tile, + // The number of warps. + typename GemmConfig_::Warps, + // The strides between warps. 
+ GemmConfig_::InstructionShape::kW * Tile::kW, + // The number of iterations to load the data. + Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>, + // The stride between iterations. + Shape, + // The shape of the instruction. + typename GemmConfig_::InstructionShape> + SharedLoadTileTraits; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef CUTLASS_USE_SUBBYTE_WMMA +/// Specialization for WMMA GEMM with signed 4-bit integer operands +template +struct WmmaGemmTileTraitsHelperA > { + /// The layout. + static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor; + + /// The input scalar. + typedef typename GemmConfig_::ScalarA Scalar; + /// The scalar stored in shared memory. + typedef typename GemmConfig_::MultiplyAdd::ScalarA MultiplyAddScalar; + + /// GemmConfig_::OutputTile::kD is in number of 'int4'. TileTraits expects number of 'Scalar'. + /// Divide by 'kInt4PerScalar' to get the number in 'Scalar'. + static int const kInt4PerScalar = sizeof(Scalar) * 2; + + /// WMMA matrix + typedef WmmaMatrix, + typename GemmConfig_::InstructionShape> + WmmaMatrix; + + /// The traits class to build the iterator to load data from global memory for A^T. + typedef GemmGlobalTileTraits< + // That's A. + GemmOperand::kA, + // A is row-major. + MatrixLayout::kRowMajor, + // The pointer is float const. + Scalar const, + // The tile has size KxM in GEMM's terminology. + Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD / kInt4PerScalar>, + // The threads are distributed as warps x 32 (the traits may reorganize). + Shape<1, + GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kInt4PerScalar), + GemmConfig_::OutputTile::kD / kInt4PerScalar>, + // The number of scalars per LDG (LDG.32 or LDG.128, etc). + GemmConfig_::kScalarsPerLdgA / kInt4PerScalar> + GlobalTileTraits; + + /// The skew. + static int const kSkew = 16 / sizeof(MultiplyAddScalar); + /// The tile. 
+ typedef Shape + Tile; + + /// The traits class to build the iterator to store data to shared memory for A^N. + typedef GemmSharedStoreTileAbTraits< + // The pointer. + MultiplyAddScalar, + // The tile has size KxM in GEMM's terminology. + Tile, + // The threads are distributed as warps x 32 (the traits may reorganize). + typename GlobalTileTraits::Threads, + // The number of scalars per STS (STS.32 or STS.128, etc). + GemmConfig_::kScalarsPerStsA / kInt4PerScalar> + SharedStoreTileTraits; + + /// The number of elements loaded in one LDG. + static int const kScalarsPerW = GemmConfig_::InstructionShape::kW * GemmConfig_::Warps::kW; + /// The traits class to build the iterator to load from shared memory for A. + typedef WmmaGemmSharedLoadTileATraits< + // The layout of the matrix. + MatrixLayout::kRowMajor, + // The pointer. + MultiplyAddScalar, + // The tile in shared memory. + Tile, + // The number of warps. + typename GemmConfig_::Warps, + // The strides between warps. + GemmConfig_::InstructionShape::kW * Tile::kW, + // The number of iterations to load the data. + Shape<1, 1, GemmConfig_::OutputTile::kW / kScalarsPerW>, + // The stride between iterations. + Shape, + // The shape of the instruction. + typename GemmConfig_::InstructionShape> + SharedLoadTileTraits; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template struct WmmaGemmTileTraitsHelperB {}; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct WmmaGemmTileTraitsHelperB +template +struct WmmaGemmTileTraitsHelperB : public GemmTileTraitsHelperB { /// The base config. typedef GemmTileTraitsHelperB Base; @@ -316,8 +591,8 @@ struct WmmaGemmTileTraitsHelperB //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct WmmaGemmTileTraitsHelperB { +template +struct WmmaGemmTileTraitsHelperB { /// The layout. 
static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor; @@ -394,6 +669,267 @@ struct WmmaGemmTileTraitsHelperB { //////////////////////////////////////////////////////////////////////////////////////////////////// +#ifdef CUTLASS_USE_SUBBYTE_WMMA +/// Specialization for WMMA GEMM with binary operands +template +struct WmmaGemmTileTraitsHelperB > { + /// The layout. + static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor; + + /// The input scalar. + typedef typename GemmConfig_::ScalarB Scalar; + /// The scalar stored in shared memory. + typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar; + + /// GemmConfig_::OutputTile::kD is in number of 'bits'. TileTraits expects number of 'Scalar'. + /// Divide by 'kBitsPerScalar' to get the number in 'Scalar'. + static int const kBitsPerScalar = sizeof(Scalar) * 8; + + /// WMMA matrix + typedef WmmaMatrix, + typename GemmConfig_::InstructionShape> + WmmaMatrix; + + /// The traits class to build the iterator to load data from global memory for B^N. + typedef GemmGlobalTileTraits< + // That's B. + GemmOperand::kB, + // A is row-major. + MatrixLayout::kColumnMajor, + // The pointer is float const. + Scalar const, + // The tile has size KxM in GEMM's terminology. + Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD / kBitsPerScalar>, + // The threads are distributed as warps x 32 (the traits may reorganize). + Shape<1, + GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kBitsPerScalar), + GemmConfig_::OutputTile::kD / kBitsPerScalar>, + // The number of scalars per LDG (LDG.32 or LDG.128, etc). + GemmConfig_::kScalarsPerLdgB / kBitsPerScalar> + GlobalTileTraits; + + /// The skew. + static int const kSkew = 16 / sizeof(MultiplyAddScalar); + /// The tile. + typedef Shape + Tile; + + /// The traits class to build the iterator to store data to shared memory for B^N. + typedef GemmSharedStoreTileAbTraits< + // The pointer. 
+ MultiplyAddScalar, + // The tile has size KxM in GEMM's terminology. + Tile, + // The threads are distributed as warps x 32 (the traits may reorganize). + typename GlobalTileTraits::Threads, + // The number of scalars per STS (STS.32 or STS.128, etc). + GemmConfig_::kScalarsPerStsB / kBitsPerScalar> + SharedStoreTileTraits; + + /// The number of elements loaded in one LDG. + static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH; + /// The traits class to build the iterator to load from shared memory for B. + typedef WmmaGemmSharedLoadTileBTraits< + // The layout of the matrix. + MatrixLayout::kColumnMajor, + // The pointer. + MultiplyAddScalar, + // The tile in shared memory. + Tile, + // The number of warps. + typename GemmConfig_::Warps, + // The strides between warps. + GemmConfig_::InstructionShape::kH * Tile::kW, + // The number of iterations to load the data. + Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>, + // The stride between iterations. + Shape, + // The shape of the instruction. + typename GemmConfig_::InstructionShape> + SharedLoadTileTraits; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef CUTLASS_USE_SUBBYTE_WMMA +/// Specialization for WMMA GEMM with unsigned 4-bit integer operands +template +struct WmmaGemmTileTraitsHelperB > { + /// The layout. + static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor; + + /// The input scalar. + typedef typename GemmConfig_::ScalarB Scalar; + /// The scalar stored in shared memory. + typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar; + + /// GemmConfig_::OutputTile::kD is in number of 'int4'. TileTraits expects number of 'Scalar'. + /// Divide by 'kInt4PerScalar' to get the number in 'Scalar'. 
+ static int const kInt4PerScalar = sizeof(Scalar) * 2; + + /// WMMA matrix + typedef WmmaMatrix, + typename GemmConfig_::InstructionShape> + WmmaMatrix; + + /// The traits class to build the iterator to load data from global memory for B^N. + typedef GemmGlobalTileTraits< + // That's B. + GemmOperand::kB, + // A is row-major. + MatrixLayout::kColumnMajor, + // The pointer is float const. + Scalar const, + // The tile has size KxM in GEMM's terminology. + Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD / kInt4PerScalar>, + // The threads are distributed as warps x 32 (the traits may reorganize). + Shape<1, + GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kInt4PerScalar), + GemmConfig_::OutputTile::kD / kInt4PerScalar>, + // The number of scalars per LDG (LDG.32 or LDG.128, etc). + GemmConfig_::kScalarsPerLdgB / kInt4PerScalar> + GlobalTileTraits; + + /// The skew. + static int const kSkew = 16 / sizeof(MultiplyAddScalar); + /// The tile. + typedef Shape + Tile; + + /// The traits class to build the iterator to store data to shared memory for B^N. + typedef GemmSharedStoreTileAbTraits< + // The pointer. + MultiplyAddScalar, + // The tile has size KxM in GEMM's terminology. + Tile, + // The threads are distributed as warps x 32 (the traits may reorganize). + typename GlobalTileTraits::Threads, + // The number of scalars per STS (STS.32 or STS.128, etc). + GemmConfig_::kScalarsPerStsB / kInt4PerScalar> + SharedStoreTileTraits; + + /// The number of elements loaded in one LDG. + static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH; + /// The traits class to build the iterator to load from shared memory for B. + typedef WmmaGemmSharedLoadTileBTraits< + // The layout of the matrix. + MatrixLayout::kColumnMajor, + // The pointer. + MultiplyAddScalar, + // The tile in shared memory. + Tile, + // The number of warps. + typename GemmConfig_::Warps, + // The strides between warps. 
+ GemmConfig_::InstructionShape::kH * Tile::kW, + // The number of iterations to load the data. + Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>, + // The stride between iterations. + Shape, + // The shape of the instruction. + typename GemmConfig_::InstructionShape> + SharedLoadTileTraits; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef CUTLASS_USE_SUBBYTE_WMMA +/// Specialization for WMMA GEMM with signed 4-bit integer operands +template +struct WmmaGemmTileTraitsHelperB > { + /// The layout. + static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor; + + /// The input scalar. + typedef typename GemmConfig_::ScalarB Scalar; + /// The scalar stored in shared memory. + typedef typename GemmConfig_::MultiplyAdd::ScalarB MultiplyAddScalar; + + /// GemmConfig_::OutputTile::kD is in number of 'int4'. TileTraits expects number of 'Scalar'. + /// Divide by 'kInt4PerScalar' to get the number in 'Scalar'. + static int const kInt4PerScalar = sizeof(Scalar) * 2; + + /// WMMA matrix + typedef WmmaMatrix, + typename GemmConfig_::InstructionShape> + WmmaMatrix; + + /// The traits class to build the iterator to load data from global memory for B^N. + typedef GemmGlobalTileTraits< + // That's B. + GemmOperand::kB, + // A is row-major. + MatrixLayout::kColumnMajor, + // The pointer is float const. + Scalar const, + // The tile has size KxM in GEMM's terminology. + Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD / kInt4PerScalar>, + // The threads are distributed as warps x 32 (the traits may reorganize). + Shape<1, + GemmConfig_::kThreads / (GemmConfig_::OutputTile::kD / kInt4PerScalar), + GemmConfig_::OutputTile::kD / kInt4PerScalar>, + // The number of scalars per LDG (LDG.32 or LDG.128, etc). + GemmConfig_::kScalarsPerLdgB / kInt4PerScalar> + GlobalTileTraits; + + /// The skew. + static int const kSkew = 16 / sizeof(MultiplyAddScalar); + /// The tile. 
+ typedef Shape + Tile; + + /// The traits class to build the iterator to store data to shared memory for B^N. + typedef GemmSharedStoreTileAbTraits< + // The pointer. + MultiplyAddScalar, + // The tile has size KxM in GEMM's terminology. + Tile, + // The threads are distributed as warps x 32 (the traits may reorganize). + typename GlobalTileTraits::Threads, + // The number of scalars per STS (STS.32 or STS.128, etc). + GemmConfig_::kScalarsPerStsB / kInt4PerScalar> + SharedStoreTileTraits; + + /// The number of elements loaded in one LDG. + static int const kScalarsPerW = GemmConfig_::InstructionShape::kH * GemmConfig_::Warps::kH; + /// The traits class to build the iterator to load from shared memory for B. + typedef WmmaGemmSharedLoadTileBTraits< + // The layout of the matrix. + MatrixLayout::kColumnMajor, + // The pointer. + MultiplyAddScalar, + // The tile in shared memory. + Tile, + // The number of warps. + typename GemmConfig_::Warps, + // The strides between warps. + GemmConfig_::InstructionShape::kH * Tile::kW, + // The number of iterations to load the data. + Shape<1, 1, GemmConfig_::OutputTile::kH / kScalarsPerW>, + // The stride between iterations. + Shape, + // The shape of the instruction. + typename GemmConfig_::InstructionShape> + SharedLoadTileTraits; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + template < /// The layout for A. MatrixLayout::Kind kLayoutA_, @@ -401,14 +937,18 @@ template < MatrixLayout::Kind kLayoutB_, /// The output tile. typename OutputTile_, + /// The input type. + typename ScalarA_, + /// The input type. + typename ScalarB_, /// The output type. typename ScalarC_, /// The accumulator type. typename Accumulator_, /// The functor to do the math in the epilogue. typename EpilogueFunctor_, - /// The number of accumulators per warp. 
- typename AccumulatorsPerWarp_, + /// Tile size for warp-level GEMM (K-by-N-by-M) + typename WarpGemmShape_, /// The shape of the WMMA instruction. typename InstructionShape_, /// The number of halfs loaded in one LDG for A. @@ -422,18 +962,20 @@ struct WmmaGemmTraitsHelper { typedef WmmaGemmConfig GemmConfig; /// The GEMM config for A. - typedef WmmaGemmTileTraitsHelperA GemmTileTraitsHelperA; + typedef WmmaGemmTileTraitsHelperA GemmTileTraitsHelperA; /// The GEMM config for B. - typedef WmmaGemmTileTraitsHelperB GemmTileTraitsHelperB; + typedef WmmaGemmTileTraitsHelperB GemmTileTraitsHelperB; /// The iterator to load A from global memory. typedef GemmGlobalIteratorAb @@ -447,7 +989,10 @@ struct WmmaGemmTraitsHelper { MemorySpace::kShared> SharedStoreIteratorA; /// The stream to load A from global memory to shared memory. - typedef GlobalLoadStream + typedef GlobalLoadStream GlobalLoadStreamA; /// The iterator to load B from global memory. @@ -462,7 +1007,10 @@ struct WmmaGemmTraitsHelper { MemorySpace::kShared> SharedStoreIteratorB; /// The stream to load B from global memory to shared memory. - typedef GlobalLoadStream + typedef GlobalLoadStream GlobalLoadStreamB; /// The iterator to load A from shared memory. @@ -472,7 +1020,7 @@ struct WmmaGemmTraitsHelper { MemorySpace::kShared, Index_, typename GemmTileTraitsHelperA::WmmaMatrix, - IteratorFragment::kWmmaMatrix> + FragmentElementType::kWmmaMatrix> SharedLoadIteratorA; /// The stream to load A from shared memory. typedef SharedLoadStream SharedLoadStreamA; @@ -483,7 +1031,7 @@ struct WmmaGemmTraitsHelper { MemorySpace::kShared, Index_, typename GemmTileTraitsHelperB::WmmaMatrix, - IteratorFragment::kWmmaMatrix> + FragmentElementType::kWmmaMatrix> SharedLoadIteratorB; /// The stream to load B from shared memory. typedef SharedLoadStream SharedLoadStreamB; @@ -518,14 +1066,18 @@ template < MatrixLayout::Kind kLayoutB_, /// The tile size for the GEMM KxNxM. 
typename OutputTile_ = Shape<64, 128, 128>, + /// The input type. + typename ScalarA_ = half, + /// The input type. + typename ScalarB_ = half, /// The output type. typename ScalarC_ = float, /// The functor to do the math in the epilogue. typename EpilogueFunctor_ = LinearScaling, /// The accumulator type. typename Accumulator_ = ScalarC_, - /// The number of accumulators per warp. - typename AccumulatorsPerWarp_ = typename WmmaGemmAccumulatorsPerWarp::Shape, + /// Tile size for warp-level GEMM (K-by-N-by-M) + typename WarpGemmShape_ = typename WmmaGemmAccumulatorsPerWarp::Shape, /// The shape of the WMMA instruction. typename InstructionShape_ = Shape<16, 16, 16>, /// The number of scalars per LDG for A. @@ -538,10 +1090,12 @@ template < typename Helper_ = WmmaGemmTraitsHelper -#include -#include -#include +#include "cutlass/load_store.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/shape.h" namespace cutlass { /////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Loads a fragment from an input iterator +// Used by convolution template CUTLASS_HOST_DEVICE void iterator_load(InputIterator &iterator, Fragment &fragment) { typename InputIterator::FragmentIterator frag_iterator(fragment); @@ -45,12 +43,12 @@ CUTLASS_HOST_DEVICE void iterator_load(InputIterator &iterator, Fragment &fragme for (int w = 0; w < InputIterator::Iterations::kW; ++w) { for (int c = 0; c < InputIterator::Iterations::kC; ++c) { if (iterator.valid(d, h, w, c)) { - iterator.get(reinterpret_cast( - frag_iterator.at(d, h, w, c)), - d, - h, - w, - c); + iterator.load_element(reinterpret_cast( + frag_iterator.at(d, h, w, c)), + d, + h, + w, + c); } } if (w < InputIterator::Iterations::kW - 1) { @@ -68,138 +66,21 @@ CUTLASS_HOST_DEVICE void iterator_load(InputIterator &iterator, Fragment &fragme iterator.inc_advance(); } -/// Loads a fragment from a shared memory input iterator -template -CUTLASS_DEVICE void 
shared_iterator_load(InputIterator &iterator, Fragment &fragment) { - typename InputIterator::FragmentIterator frag_iterator(fragment); - for (int d = 0; d < InputIterator::Iterations::kD; ++d) { - for (int h = 0; h < InputIterator::Iterations::kH; ++h) { - for (int w = 0; w < InputIterator::Iterations::kW; ++w) { - for (int c = 0; c < InputIterator::Iterations::kC; ++c) { - int const offset = - ComputeOffsetFromStrides::get( - d, h, w, c); - - FragmentLoad::load(frag_iterator.at(d, h, w, c), - iterator.data(), - offset); - } - } - } - } -} - -/// Loads a fragment from a shared memory input iterator -template -CUTLASS_DEVICE void shared_iterator_load(InputIterator &iterator, Fragment &fragment, int d) { - typename InputIterator::FragmentIterator frag_iterator(fragment); - for (int h = 0; h < InputIterator::Iterations::kH; ++h) { - for (int w = 0; w < InputIterator::Iterations::kW; ++w) { - for (int c = 0; c < InputIterator::Iterations::kC; ++c) { - int const offset = - ComputeOffsetFromStrides::get( - d, h, w, c); - - FragmentLoad::load(frag_iterator.at(0, h, w, c), - iterator.data(), - offset); - } - } - } -} - -/// Loads a fragment from an input iterator, masked by a predicate iterator -template -CUTLASS_HOST_DEVICE void iterator_load_post_increment(InputIterator &iterator, - Fragment &fragment, - typename InputIterator::Index offset, - ConstPredicateAdapter predicate_adapter) { - for (int d = 0; d < InputIterator::Iterations::kD; ++d, iterator.inc_d()) { - for (int h = 0; h < InputIterator::Iterations::kH; ++h, iterator.inc_h()) { - for (int w = 0; w < InputIterator::Iterations::kW; ++w, iterator.inc_w()) { - if (predicate_adapter.at(d, h, w, 0)) { - int idx = InputIterator::Tile::kC * - (w + InputIterator::Iterations::kW * (h + InputIterator::Iterations::kH * d)); - - Load:: - load(reinterpret_cast(fragment[idx]), - iterator.data(), - offset); - } - } - } - } -} - -/// Loads a fragment from an input iterator -template -CUTLASS_HOST_DEVICE void 
iterator_load_post_increment(InputIterator &iterator, - Fragment &fragment, - typename InputIterator::Index offset = 0) { - TrivialPredicateTileAdapter pred; - iterator_load_post_increment(iterator, fragment, offset, pred); -} - -/// Loads a fragment from an input iterator -template -CUTLASS_HOST_DEVICE void iterator_load_post_increment(InputIterator &iterator, - Fragment &fragment, - ConstPredicateAdapter pred_it) { - iterator_load_post_increment(iterator, fragment, 0, pred_it); -} - -template -CUTLASS_HOST_DEVICE void iterator_load(InputIterator const &_iterator, - Fragment &fragment, - typename InputIterator::Index offset, - ConstPredicateAdapter predicate_adapter) { - InputIterator iterator(_iterator); - iterator_load_post_increment(iterator, fragment, offset, predicate_adapter); -} - -/// Loads a fragment from an input iterator -template -CUTLASS_HOST_DEVICE void iterator_load(InputIterator const &iterator, - Fragment &fragment, - typename InputIterator::Index offset = 0) { - TrivialPredicateTileAdapter pred; - iterator_load(iterator, fragment, offset, pred); -} - -/// Loads a fragment from an input iterator -template -CUTLASS_HOST_DEVICE void iterator_load(InputIterator const &iterator, - Fragment &fragment, - ConstPredicateAdapter pred_it) { - iterator_load(iterator, fragment, 0, pred_it); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - -/// Stores a fragment to an output iterator template CUTLASS_HOST_DEVICE void iterator_store(OutputIterator &iterator, Fragment &fragment) { typename OutputIterator::FragmentIterator frag_iterator(fragment); for (int d = 0; d < OutputIterator::Iterations::kD; ++d) { for (int h = 0; h < OutputIterator::Iterations::kH; ++h) { for (int w = 0; w < OutputIterator::Iterations::kW; ++w) { - if (iterator.valid(d, h, w, 0)) { - iterator.set(reinterpret_cast( - frag_iterator.at(d, h, w, 0)), - d, - h, - w, - 0); + for (int c = 0; c < OutputIterator::Iterations::kC; ++c) { + 
if (iterator.valid(d, h, w, c)) { + iterator.store_element(reinterpret_cast( + frag_iterator.at(d, h, w, c)), + d, + h, + w, + c); + } } if (w < OutputIterator::Iterations::kW - 1) { iterator.inc_w(); @@ -215,104 +96,6 @@ CUTLASS_HOST_DEVICE void iterator_store(OutputIterator &iterator, Fragment &frag } iterator.inc_advance(); } - -/// Stores a fragment to a shared memory output iterator -template -CUTLASS_DEVICE void shared_iterator_store(OutputIterator &iterator, Fragment const &fragment) { - typename OutputIterator::FragmentConstIterator frag_iterator(fragment); - for (int d = 0; d < OutputIterator::Iterations::kD; ++d) { - for (int h = 0; h < OutputIterator::Iterations::kH; ++h) { - for (int w = 0; w < OutputIterator::Iterations::kW; ++w) { - for (int c = 0; c < OutputIterator::Iterations::kC; ++c) { - int const offset = - ComputeOffsetFromStrides::get( - d, h, w, c); - - FragmentStore::store(frag_iterator.at(d, h, w, c), - iterator.data(), - offset); - } - } - } - } -} - //////////////////////////////////////////////////////////////////////////////////////////////////// -/// Stores a fragment to an output iterator, masked by a predicate iterator -template -CUTLASS_HOST_DEVICE void iterator_store_post_increment(OutputIterator &iterator, - Fragment const &fragment, - typename OutputIterator::Index offset, - ConstPredicateAdapter predicate_adapter) { - for (int d = 0; d < OutputIterator::Iterations::kD; ++d, iterator.inc_d()) { - for (int h = 0; h < OutputIterator::Iterations::kH; ++h, iterator.inc_h()) { - for (int w = 0; w < OutputIterator::Iterations::kW; ++w, iterator.inc_w()) { - if (predicate_adapter.at(d, h, w, 0)) { - int idx = OutputIterator::Tile::kC * - (w + OutputIterator::Iterations::kW * (h + OutputIterator::Iterations::kH * d)); - - Store:: - store(reinterpret_cast(fragment[idx]), - iterator.data(), - offset); - } - } - } - } -} - -/// Stores a fragment to an output iterator -template -CUTLASS_HOST_DEVICE void 
iterator_store_post_increment(OutputIterator &iterator, - Fragment const &fragment, - typename OutputIterator::Index offset = 0) { - TrivialPredicateTileAdapter pred; - iterator_store_post_increment(iterator, fragment, offset, pred); -} - -/// Stores a fragment to an output iterator -template -CUTLASS_HOST_DEVICE void iterator_store_post_increment(OutputIterator &iterator, - Fragment const &fragment, - ConstPredicateAdapter pred_it) { - iterator_store_post_increment(iterator, fragment, 0, pred_it); -} - -/// Stores a fragment to an output iterator, masked by a predicate iterator -template -CUTLASS_HOST_DEVICE void iterator_store(OutputIterator const &_iterator, - Fragment const &fragment, - typename OutputIterator::Index offset, - ConstPredicateAdapter predicate_adapter) { - OutputIterator iterator(_iterator); - iterator_store_post_increment(iterator, fragment, offset, predicate_adapter); -} - -/// Stores a fragment to an output iterator -template -CUTLASS_HOST_DEVICE void iterator_store(OutputIterator const &iterator, - Fragment const &fragment, - typename OutputIterator::Index offset = 0) { - TrivialPredicateTileAdapter pred; - iterator_store(iterator, fragment, offset, pred); -} - -/// Stores a fragment to an output iterator -template -CUTLASS_HOST_DEVICE void iterator_store(OutputIterator const &iterator, - Fragment const &fragment, - ConstPredicateAdapter pred_it) { - iterator_store(iterator, fragment, 0, pred_it); -} - -/////////////////////////////////////////////////////////////////////////////////////////////////// - } // namespace cutlass diff --git a/cutlass/kernel_launch.h b/cutlass/kernel_launch.h new file mode 100644 index 000000000..ee37b2fda --- /dev/null +++ b/cutlass/kernel_launch.h @@ -0,0 +1,67 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. 
+ * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines structures and helpers to launch CUDA kernels within CUTLASS. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure containing the basic launch configuration of a CUDA kernel. 
+struct KernelLaunchConfiguration { + + /// CUDA grid dimensions + dim3 grid; + + /// CUDA threablock dimensions + dim3 block; + + /// Bytes of dynamically allocated SMEM in addition to static SMEM + size_t dynamic_smem; + + // + // Methods + // + + /// Constructs a KernellaunchConfiguration object + CUTLASS_HOST_DEVICE + KernelLaunchConfiguration( + dim3 _grid = dim3(1,1,1), + dim3 _block = dim3(1,1,1), + size_t _dynamic_smem = 0 + ): + grid(_grid), + block(_block), + dynamic_smem(_dynamic_smem) { } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/cutlass/load_store.h b/cutlass/load_store.h index 5cb5eb672..db09dd0a4 100644 --- a/cutlass/load_store.h +++ b/cutlass/load_store.h @@ -27,8 +27,7 @@ */ #pragma once -#include - +#include "cutlass/vector.h" namespace cutlass { //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -44,45 +43,68 @@ struct MemorySpace { }; }; +/// Specifies whether iterator storage fragment consists of Scalar values or WMMA matrix +struct FragmentElementType { + enum Kind { kScalar, kWmmaMatrix }; +}; + //////////////////////////////////////////////////////////////////////////////////////////////////// template 1), - size_t = (sizeof(Scalar_) * Lanes_)> + FragmentElementType::Kind kFragmentElementType = FragmentElementType::kScalar, + typename FragmentElement_ = Scalar_, + int kStride = 1, + size_t size = (sizeof(Scalar_) * kAccessSize)> struct Load { /// The output type. - typedef typename Vectorize::Type AccessType; + typedef typename Vectorize::Type AccessType; /// The load function. 
- static CUTLASS_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) { - dst = reinterpret_cast(&pointer[offset])[0]; + static CUTLASS_HOST_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) { + dst = *reinterpret_cast(pointer + offset); + } + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for 16b loads +template +struct Load { + /// The output type. + typedef typename Vectorize::Type AccessType; + + /// The load function. + static CUTLASS_HOST_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) { + reinterpret_cast(dst) = reinterpret_cast(&pointer[offset])[0]; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct Load { +template +struct Load { /// The output type. - typedef typename Vectorize::Type AccessType; + typedef typename Vectorize::Type AccessType; - /// The store function. - static CUTLASS_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) { + /// The load function. + static CUTLASS_HOST_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) { dst.registers[0] = reinterpret_cast(&pointer[offset])[0]; } + }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct Load { +template +struct Load { /// The output type. - typedef typename Vectorize::Type AccessType; + typedef typename Vectorize::Type AccessType; - /// The store function. - static CUTLASS_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) { + /// The load function. 
+ static CUTLASS_HOST_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) { uint2 tmp = reinterpret_cast(&pointer[offset])[0]; dst.registers[0] = tmp.x; dst.registers[1] = tmp.y; @@ -91,13 +113,13 @@ struct Load { //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct Load { +template +struct Load { /// The output type. typedef typename Vectorize::Type AccessType; - /// The store function. - static CUTLASS_DEVICE void load(AccessType& dst, double const* pointer, int offset) { + /// The load function. + static CUTLASS_HOST_DEVICE void load(AccessType& dst, double const* pointer, int offset) { double2 tmp = reinterpret_cast(&pointer[offset])[0]; dst[0] = tmp.x; dst[1] = tmp.y; @@ -108,13 +130,13 @@ struct Load { #if defined(__CUDACC_VERSION_MAJOR) && __CUDACC_VERSION_MAJOR < 10 // WAR bug in NVCC where the upper and lower half of the register end up being the same -template -struct Load { +template +struct Load { /// The output type. typedef typename Vectorize::Type AccessType; - /// The store function. - static CUTLASS_DEVICE void load(AccessType& dst, half const* pointer, int offset) { + /// The load function. + static CUTLASS_HOST_DEVICE void load(AccessType& dst, half const* pointer, int offset) { int2 tmp = reinterpret_cast(&pointer[offset])[0]; dst.registers[0] = tmp.x; dst.registers[1] = tmp.y; @@ -129,13 +151,13 @@ struct Load { //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct Load { +template +struct Load { /// The output type. - typedef typename Vectorize::Type AccessType; + typedef typename Vectorize::Type AccessType; - /// The store function. - static CUTLASS_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) { + /// The load function. 
+ static CUTLASS_HOST_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) { uint4 tmp = reinterpret_cast(&pointer[offset])[0]; dst.registers[0] = tmp.x; dst.registers[1] = tmp.y; @@ -147,29 +169,45 @@ struct Load { //////////////////////////////////////////////////////////////////////////////////////////////////// template 1), - size_t = (sizeof(Scalar_) * Lanes_)> + FragmentElementType::Kind kFragmentElementType = FragmentElementType::kScalar, + typename FragmentElement_ = Scalar_, + int kStride = 1, + size_t size = (sizeof(Scalar_) * kAccessSize)> struct Store { /// The output type. - typedef typename Vectorize::Type AccessType; + typedef typename Vectorize::Type AccessType; /// The store function. - static CUTLASS_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) { - pointer[offset] = src; + static CUTLASS_HOST_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) { + pointer[offset] = *reinterpret_cast(&src); } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct Store { +template +struct Store { /// The output type. - typedef typename Vectorize::Type AccessType; + typedef typename Vectorize::Type AccessType; /// The store function. - static CUTLASS_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) { + static CUTLASS_HOST_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) { + uint16_t* addr = reinterpret_cast(&pointer[offset]); + addr[0] = reinterpret_cast(src); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Store { + /// The output type. + typedef typename Vectorize::Type AccessType; + + /// The store function. 
+ static CUTLASS_HOST_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) { uint32_t* addr = reinterpret_cast(&pointer[offset]); addr[0] = src.registers[0]; } @@ -177,13 +215,13 @@ struct Store { //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct Store { +template +struct Store { /// The output type. - typedef typename Vectorize::Type AccessType; + typedef typename Vectorize::Type AccessType; /// The store function. - static CUTLASS_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) { + static CUTLASS_HOST_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) { uint2* addr = reinterpret_cast(&pointer[offset]); addr[0] = make_uint2(src.registers[0], src.registers[1]); } @@ -191,13 +229,13 @@ struct Store { //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct Store { +template +struct Store { /// The output type. typedef typename Vectorize::Type AccessType; /// The store function. - static CUTLASS_DEVICE void store(AccessType const& src, double* pointer, int offset) { + static CUTLASS_HOST_DEVICE void store(AccessType const& src, double* pointer, int offset) { double2* addr = reinterpret_cast(&pointer[offset]); addr[0] = make_double2(src[0], src[1]); } @@ -205,13 +243,13 @@ struct Store { //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct Store { +template +struct Store { /// The output type. - typedef typename Vectorize::Type AccessType; + typedef typename Vectorize::Type AccessType; /// The store function. 
- static CUTLASS_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) { + static CUTLASS_HOST_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) { uint4* addr = reinterpret_cast(&pointer[offset]); addr[0] = make_uint4(src.registers[0], src.registers[1], src.registers[2], src.registers[3]); } @@ -219,4 +257,123 @@ struct Store { //////////////////////////////////////////////////////////////////////////////////////////////////// +template +struct Load { + /// The output type. + typedef FragmentElement_ AccessType; + + /// The load function. + static CUTLASS_HOST_DEVICE void load(AccessType& value, Scalar_ const* pointer, int offset) { + value.load(&pointer[offset], kStride); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Load, + kAccessSize, + Memory_, + FragmentElementType::kWmmaMatrix, + FragmentElement_, + kStride, + size> { + /// The output type. + typedef FragmentElement_ AccessType; + + /// The load function. + static CUTLASS_HOST_DEVICE void load(AccessType& value, Vector const* pointer, + int offset) { + value.load(&pointer[offset], kStride * 32); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Load, + kAccessSize, + Memory_, + FragmentElementType::kWmmaMatrix, + FragmentElement_, + kStride, + size> { + /// The output type. + typedef FragmentElement_ AccessType; + + /// The load function. + static CUTLASS_HOST_DEVICE void load(AccessType& value, Vector const* pointer, + int offset) { + value.load(&pointer[offset], kStride * 8); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Load, + kAccessSize, + Memory_, + FragmentElementType::kWmmaMatrix, + FragmentElement_, + kStride, + size> { + /// The output type. + typedef FragmentElement_ AccessType; + + /// The load function. 
+ static CUTLASS_HOST_DEVICE void load(AccessType& value, Vector const* pointer, + int offset) { + value.load(&pointer[offset], kStride * 8); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +template +struct Store { + /// The input type. + typedef FragmentElement_ AccessType; + + /// The store function. + static CUTLASS_HOST_DEVICE void store(AccessType const& value, Scalar_* pointer, int offset) { + value.store(&pointer[offset], kStride); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace cutlass diff --git a/cutlass/matrix_traits.h b/cutlass/matrix_traits.h index 77e8b7062..08a43a99a 100644 --- a/cutlass/matrix_traits.h +++ b/cutlass/matrix_traits.h @@ -27,13 +27,327 @@ */ #pragma once +#include "cutlass/coord.h" + namespace cutlass { //////////////////////////////////////////////////////////////////////////////////////////////////// -/// Describes layouts of matrices +/// MatrixCoord wraps Coord<2, int> to provide a helper for accessing named dimensions. Classes +/// expecting a coordinate in the rank=2 index space of a matrix should use MatrixCoord. 
+struct MatrixCoord : public Coord<2, int> { + + /// Integer-valued index + typedef int Index; + + /// Base type is a Coord of rank=2 + typedef Coord<2, Index> Base; + + /// Rows dimension + static int const kRow = 0; + + /// Columns dimension + static int const kColumn = 1; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + MatrixCoord() { } + + /// Constructs from Coord<2> + CUTLASS_HOST_DEVICE + MatrixCoord(Coord<2, Index> const &coord): Base(coord) { } + + /// Helper to construct from a row and column + CUTLASS_HOST_DEVICE + MatrixCoord(Index row, Index column): Base(make_Coord(row, column)) { } + + /// Returns the row of the coordinate + CUTLASS_HOST_DEVICE + Index const & row() const { return this->at(kRow); } + + /// Returns the row of the coordinate + CUTLASS_HOST_DEVICE + Index & row() { return this->at(kRow); } + + /// Returns the column of the coordinate + CUTLASS_HOST_DEVICE + Index const & column() const { return this->at(kColumn); } + + /// Returns the column of the coordinate + CUTLASS_HOST_DEVICE + Index & column() { return this->at(kColumn); } + + // + // Coord operators + // + + /// Element-wise addition + CUTLASS_HOST_DEVICE + MatrixCoord operator+(Base const& b) const { + return MatrixCoord(Base::operator+(b)); + } + + /// Element-wise subtraction + CUTLASS_HOST_DEVICE + MatrixCoord operator-(Base const& b) const { + return MatrixCoord(Base::operator-(b)); + } + + /// Element-wise multiplication + CUTLASS_HOST_DEVICE + MatrixCoord operator*(Base const& b) const { + return MatrixCoord(Base::operator*(b)); + } + + /// Element-wise division + CUTLASS_HOST_DEVICE + MatrixCoord operator/(Base const& b) const { + return MatrixCoord(Base::operator/(b)); + } + + /// In-place addition + CUTLASS_HOST_DEVICE + MatrixCoord& operator+=(Base const& b) { + Base::operator+=(b); + return *this; + } + + /// In-place subtraction + CUTLASS_HOST_DEVICE + MatrixCoord& operator-=(Base const& b) { + Base::operator-=(b); + return *this; + } + + /// 
In-place multiplication + CUTLASS_HOST_DEVICE + MatrixCoord& operator*=(Base const& b) { + Base::operator*=(b); + return *this; + } + + /// In-place division + CUTLASS_HOST_DEVICE + MatrixCoord& operator/=(Base const& b) { + Base::operator/=(b); + return *this; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines data layouts of various matrix formats usable by TensorRef and other classes. +// +// The following define classes satisfying the TensorRefMapFunc concept. These must support the +// following operations, where func is an instance of type TensorRefMapFunc. +// +// Coord = func(Coord); +// +// Though not required to be usable by TensorRef, each of the following also define a helper +// function to map the "leading dimension" to an appropriate stride vector. Implementations +// following this convention should also implement the following static method: +// +// Coord stride = TensorRefMapFunc::stride(leading_dim); +// struct MatrixLayout { + + /// Enumeration defining fundamental contiguous layouts. enum Kind { kRowMajor, kColumnMajor }; + + // + // TensorRefMapFunc definitions for common layouts + // + + /// Mapping function for row-major matrices + struct RowMajor { + static int const kStorageRank = 2; + /// Maps (i, j) to (i, j) + CUTLASS_HOST_DEVICE + Coord operator()(MatrixCoord const &coord) const { + return coord; + } + }; + + /// Mapping function for column-major matrices + struct ColumnMajor { + static int const kStorageRank = 2; + /// Maps (i, j) to (j, i) + CUTLASS_HOST_DEVICE + Coord operator()(MatrixCoord const &coord) const { + return make_Coord(coord.column(), coord.row()); + } + }; + + /// Mapping function for interleaved matrices. Matrix is structured + /// as row-major arrangement of fixed-size columns. 
+ template + struct RowMajorInterleaved { + + /// Rank of storage n-D array + static int const kStorageRank = 3; + + /// Interleaving size + static int const kInterleave = Interleave; + + /// Maps (row, col) to (row, col, row) + CUTLASS_HOST_DEVICE + Coord operator()(MatrixCoord const &coord) const { + return make_Coord( + coord.row() / kInterleave, + coord.column(), + coord.row() % kInterleave + ); + } + + /// Helper to compute stride vector from leading dimension + CUTLASS_HOST_DEVICE + static Coord stride(int ldm) { + return make_Coord( + ldm * kInterleave, + kInterleave, + 1 + ); + } + }; + + /// Mapping function for interleaved matrices. Matrix is structured + /// as column-major arrangement of fixed-size rows. + template + struct ColumnMajorInterleaved { + + /// Rank of storage n-D array + static int const kStorageRank = 3; + + /// Interleaving size + static int const kInterleave = Interleave; + + /// Maps (row, col) to (col, row, col) + CUTLASS_HOST_DEVICE + Coord operator()(MatrixCoord const &coord) const { + return make_Coord( + coord.column() / kInterleave, + coord.row(), + coord.column() % kInterleave + ); + } + + /// Helper to compute stride vector from leading dimension + CUTLASS_HOST_DEVICE + static Coord stride(int ldm) { + return make_Coord( + ldm * kInterleave, + kInterleave, + 1 + ); + } + }; + + /// Mapping function for scenario in which layout is row-major or column-major but this information + /// is only available at runtime. + struct ContiguousLayout { + /// Arbitrary storage rank + static int const kStorageRank = 3; + + /// Dimension of rows + static int const kRow = 0; + + /// Dimension of columns + static int const kColumn = 1; + + /// Mapping function defined by runtime variable. 
Returns coordinates in n-D storage array + /// as (matrix row, matrix colum, 0) + CUTLASS_HOST_DEVICE + Coord operator()(MatrixCoord const &coord) const { + return make_Coord(coord.row(), coord.column(), 0); + } + + /// Helper to construct a stride vector based on contiguous matrix layout and leading dimension + CUTLASS_HOST_DEVICE + static Coord stride(MatrixLayout::Kind layout, int ldm) { + if (layout == MatrixLayout::kRowMajor) { + return make_Coord(ldm, 1, 1); + } + return make_Coord(1, ldm, 1); + } + }; + + /// Mapping function for block-linear matrices. Matrix is structured + /// as column-major arrangement of 2D tiles (that are column-major). + template + struct ColumnMajorBlockLinear { + + /// Rank of storage n-D array + static int const kStorageRank = 4; + + /// Interleaving size in rows dimension + static int const kBlockRows = BlockRows; + + /// Interleaving size in columns dimension + static int const kBlockColumns = BlockColumns; + + /// Maps (row, col) to (col, row, col, row) + CUTLASS_HOST_DEVICE + Coord operator()(MatrixCoord const &coord) const { + return make_Coord( + coord.column() / kBlockColumns, + coord.row() / kBlockRows, + coord.column() % kBlockColumns, + coord.row() % kBlockRows + ); + } + + /// Helper to compute stride vector from leading dimension + CUTLASS_HOST_DEVICE + static Coord stride(int ldm) { + return make_Coord( + ldm * kBlockRows * kBlockColumns, + kBlockRows * kBlockColumns, + kBlockRows, + 1 + ); + } + }; + + /// Mapping function for block-linear matrices. 
Matrix is structured + /// as row-major arrangement of 2D tiles (that are row-major) + template + struct RowMajorBlockLinear { + + /// Rank of storage n-D array + static int const kStorageRank = 4; + + /// Interleaving size in rows dimension + static int const kBlockRows = BlockRows; + + /// Interleaving size in columns dimension + static int const kBlockColumns = BlockColumns; + + /// Maps (row, col) to (row, col, row, col) + CUTLASS_HOST_DEVICE + Coord operator()(MatrixCoord const &coord) const { + return make_Coord( + coord.row() / kBlockRows, + coord.column() / kBlockColumns, + coord.row() % kBlockRows, + coord.column() % kBlockColumns + ); + } + + /// Helper to compute stride vector from leading dimension + CUTLASS_HOST_DEVICE + static Coord stride(int ldm) { + return make_Coord( + ldm * kBlockRows * kBlockColumns, + kBlockRows * kBlockColumns, + kBlockColumns, + 1 + ); + } + }; }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -45,4 +359,14 @@ struct GemmOperand { //////////////////////////////////////////////////////////////////////////////////////////////////// +/// Transformation applied to matrix operands +struct MatrixTransform { + enum Kind { + kNone, /// no operation + kConjugate, /// conjugate + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace cutlass diff --git a/cutlass/predicate_vector.h b/cutlass/predicate_vector.h index 81668577e..4a37d017d 100644 --- a/cutlass/predicate_vector.h +++ b/cutlass/predicate_vector.h @@ -28,12 +28,13 @@ */ #pragma once +#include #include -#include -#include +#include "cutlass/cutlass.h" +#include "cutlass/shape.h" -#include +#include "cutlass/util/platform.h" namespace cutlass { @@ -114,7 +115,7 @@ struct PredicateVector { // Make sure no one tries to put more than 8 bits in a byte :) static_assert(kPredicatesPerByte <= 8, "kPredicatesPerByte must fit within an actual byte"); // Make 
sure the "offsetted" bits fit in one byte. - static_assert(kPredicateStart + kPredicatesPerByte < 8, + static_assert(kPredicateStart + kPredicatesPerByte <= 8, "The offsetted predicates must fit within an actual byte."); /// Storage type of individual elements diff --git a/cutlass/reshape_tile.h b/cutlass/reshape_tile.h index 55aebfcaf..67faa602a 100644 --- a/cutlass/reshape_tile.h +++ b/cutlass/reshape_tile.h @@ -27,7 +27,7 @@ */ #pragma once -#include +#include "cutlass/shape.h" namespace cutlass { diff --git a/cutlass/shape.h b/cutlass/shape.h index 4f6b222ee..b8c0c66f3 100644 --- a/cutlass/shape.h +++ b/cutlass/shape.h @@ -27,7 +27,7 @@ */ #pragma once -#include +#include "cutlass/cutlass.h" namespace cutlass { @@ -128,6 +128,17 @@ struct ShapeDiv { //////////////////////////////////////////////////////////////////////////////////////////////////// +template +struct ShapeDivCeiling { + typedef Shape<(A_::kD + B_::kD - 1) / B_::kD, + (A_::kH + B_::kH - 1) / B_::kH, + (A_::kW + B_::kW - 1) / B_::kW, + (A_::kC + B_::kC - 1) / B_::kC> + Shape; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + template struct ShapeMax { typedef Shape<(A_::kD > B_::kD ? 
A_::kD : B_::kD), @@ -150,12 +161,12 @@ struct ShapeMin { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct ShapeStrides { typedef Shape + elementsPerAccess> Shape; }; @@ -167,7 +178,7 @@ struct ShapeStrides { */ template struct ComputeOffsetFromShape { - static CUTLASS_DEVICE int get(int d, int h, int w, int c) { + static CUTLASS_HOST_DEVICE int get(int d, int h, int w, int c) { // clang-format off return d * Shape_::kH * Shape_::kW * Shape_::kC + h * Shape_::kW * Shape_::kC + @@ -179,73 +190,19 @@ struct ComputeOffsetFromShape { //////////////////////////////////////////////////////////////////////////////////////////////////// -/** -* @brief Compute the offset for the given coordinates in a cube with a depth of 1 -* @tparam kSh Elements in the H dimension -* @tparam kSw Elements in the W dimension -* @tparam kSc Separation between two elements in "elements" -*/ -template -struct ComputeOffsetFromShape > { - static CUTLASS_DEVICE int get(int d, int h, int w, int c) { - return h * kSw_ * kSc_ + w * kSc_ + c; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/** -* @brief Compute the offset for the given coordinates in a cube with one channel and a depth of 1 -* @tparam kSh Elements in the H dimension -* @tparam kSw Elements in the W dimension -*/ -template -struct ComputeOffsetFromShape > { - static CUTLASS_DEVICE int get(int d, int h, int w, int c) { return h * kSw_ + w; } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - /** * @brief Compute the offset for the given coordinates in a cube * @tparam A \ref layout_concept where each dimension of the cube specifies the corresponding stride. 
*/ template struct ComputeOffsetFromStrides { - static CUTLASS_DEVICE int get(int d, int h, int w, int c) { + static CUTLASS_HOST_DEVICE int get(int d, int h, int w, int c) { return d * Strides_::kD + h * Strides_::kH + w * Strides_::kW + c * Strides_::kC; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// -/** -* @brief Compute the offset for the given coordinates in a cube with a depth of 1 -* @tparam S_h Stride in the H dimension in scalars -* @tparam S_w Stride in the W dimension in scalars -* @tparam S_c Stride between two scalars. -*/ -template -struct ComputeOffsetFromStrides > { - static CUTLASS_DEVICE int get(int d, int h, int w, int c) { - return h * S_h_ + w * S_w_ + c * S_c_; - } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -/** -* @brief Compute the offset for the given coordinates in a cube with one channel and a depth of 1 -* @tparam S_h Stride in the H dimension in scalars -* @tparam S_w Stride in the W dimension in scalars -*/ -template -struct ComputeOffsetFromStrides > { - static CUTLASS_DEVICE int get(int d, int h, int w, int c) { return h * S_h_ + w * S_w_; } -}; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - /** * @brief Decompose threadId.x into coordinate of a cube whose dimensions are specified by Threads_. 
* Afterwards compute the offset of those coordinates using Strides_ diff --git a/cutlass/tensor_ref.h b/cutlass/tensor_ref.h index 8ef31e3b8..09134190c 100644 --- a/cutlass/tensor_ref.h +++ b/cutlass/tensor_ref.h @@ -27,125 +27,613 @@ */ #pragma once -#include - -#include -#include -#include +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/vector.h" namespace cutlass { -//////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// -/// Structure modeling a pointer and stride into a tensor -template +/// Default mapping function from coordinates in a tensor's index space into the n-D array held +/// in memory. Assumes StorageRank = Rank +template +struct IdentityTensorMapFunc { + static int const kStorageRank = Rank; + CUTLASS_HOST_DEVICE + Coord operator()(Coord const &coord) const { + return coord; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/* \brief Structure modeling a pointer and stride into a tensor. + + A tensor consists of an index space with Rank_ dimensions. It is stored in memory modeled + as an n-D array, where n = StorageRank_. A mapping function maps the logical coordinates of the + tensor's index space into the n-D array, and a stride vector maps the n-D array to linear memory. + + CUTLASS requires the n-D array's least significant, "fastest changing" dimension to + be contiguous in memory. It therefore has a stride of 1 and is not stored. Construction is offered + from vectors of full StorageRank and of the 'compact' rank, though it is in error to construct + with the least significant stride != 1. + + The requirement that the least significant dimension be consecutive enables numerous optimizations + and assumptions about vectorizing memory accesses throughout CUTLASS. 
It also matches various + BLAS conventions in which only the "leading dimension" or most significant stride of a rank=2 + matrix is provided. + + This does affect the ability of constructing arbitrary "sparse" 2-D matrices in memory where all + stride elements are > 1. This can be overcome by defining a custom mapping function and a + StorageRank of 3 or more. + + + Examples: + + (These examples use helpers for matrix layouts defined in cutlass/matrix_traits.h) + + 1. Column-major matrix may be represented as a rank=2 tensor: + + TensorRef A(ptr_A, make_Coord(ldm, 1)); + + 2. Row-major matrix may be represented as a rank=2 tensor: + + TensorRef B(ptr_A, ldm); + + 3. An interleaved matrix may be represented as a rank=2 tensor: + + TensorRef > C; + + 4. Defining a sparse matrix with arbitrary strides in each dimension + + struct ContiguousLayout { + + /// Arbitrary storage rank + static int const kStorageRank = 3; + + /// Mapping function defined by runtime stride configuration + CUTLASS_HOST_DEVICE + Coord<3> operator()(MatrixCoord const &coord) const { + return make_Coord(coord.row(), coord.column(), 0); + } + }; + + typedef TensorRef ContiguousTensorRef; + + // Construct the TensorRef object from a pair of stride values + ContiguousTensorRef D(ptr_D, make_Coord(row_stride, column_stride)); + + + 5. A helper exists to define a TensorRef for a contiguous matrix whose layout + is not known at compile time. 
+ + MatrixLayout::Kind layout; // Could be MatrixLayout::kRowMajor or MatrixLayout::kColumnMajor + int ldm; // leading dimension + + ContiguousTensorRef E(ptr_E, ContiguousLayout::stride(layout, ldm)); + +*/ +template < + /// Data type of element stored within tensor + typename Storage_, + /// Rank of logical tensor + int Rank_, + /// Maps a Coord in the logical tensor index space to the internal n-D array + typename MapFunc_ = IdentityTensorMapFunc, + /// Rank of internal n-D array + int StorageRank_ = MapFunc_::kStorageRank, + /// Index type used for coordinates + typename Index_ = int, + /// Index type used for offsets and pointer differences + typename LongIndex_ = long long +> class TensorRef { public: /// Data type of individual access typedef Storage_ Storage; - /// Rank of tensor - static int const Rank = Rank_; + /// Logical rank of tensor index space + static int const kRank = Rank_; + + /// Mapping function from logical coordinate to internal n-D array + typedef MapFunc_ MapFunc; + + /// Rank of internal storage + static int const kStorageRank = StorageRank_; + + /// Index type + typedef Index_ Index; + + /// Typically, strides in memory can be very large + typedef LongIndex_ LongIndex; + + /// Coordinate in logical tensor space + typedef Coord TensorCoord; + + /// Coordinate in storage n-D array + typedef Coord StorageCoord; + + /// Stride vector in storage coordinage space - assumes least significant stride + /// is 1 and does not store it. + typedef Coord StrideVector; + + /// Tensor reference to of constant value + typedef TensorRef< + typename platform::remove_const::type const, + Rank_, + MapFunc_, + StorageRank_, + Index_, + LongIndex_> ConstTensorRef; + + /// Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a + /// scalar, but degenerate cases such as these are difficult to accommodate without + /// extensive C++ metaprogramming or support for zero-length arrays. 
+ static_assert(kRank > 0, "Cannot define a zero-rank TensorRef"); + + // + // Definitions included for backwards compatibility - to be removed in next major release + // + + /// Coordinate in logical tensor space + typedef TensorCoord Coord_t; + + /// Logical rank of tensor index space + static int const Rank = kRank; private: - // - // Data members - // - /// Pointer to storage element + /// Pointer Storage* ptr_; - /// Stride information - Coord stride_; + /// Stride vector - fastest-changing stride assumed to be 1 and not stored + StrideVector stride_; + + /// Maps a logical coordinate to an n-D array's tensor space + MapFunc coord_map_; public: + // // Methods // - /// Default ctor + /// Helper for 1-D memory. All higher ranks are projected onto the fastest changing rank. CUTLASS_HOST_DEVICE - TensorRef() : ptr_(nullptr) {} + TensorRef(Storage *ptr = nullptr): ptr_(ptr) { + for (int i = 0; i < kStorageRank - 1; ++i) { + stride_[i] = 1; + } + } - /// Constructs from a pointer, size, and stride + /// Helper to construct from a pointer and single stride element for 2-D pitch linear memory. + // Higher ranks are projected onto the fastest-changing rank. CUTLASS_HOST_DEVICE - TensorRef(Storage* ptr, Coord stride) : ptr_(ptr), stride_(stride) {} + TensorRef(Storage* ptr, Index ldm) { + ptr_ = ptr; + for (int i = 0; i < kStorageRank - 1; ++i) { + stride_[i] = ldm; + } + } + + /// Constructs from a single pointer and stride vector + CUTLASS_HOST_DEVICE + TensorRef(Storage* ptr, StrideVector const& stride) : ptr_(ptr), stride_(stride) { + + } + + /// Constructs from a pointer and a stride vector of size kRank. If fastest changing + /// stride is not 1, construction fails and subsequent calls to good() will return false. 
+ CUTLASS_HOST_DEVICE + TensorRef(Storage* ptr, StorageCoord const& stride) { + // Fastest-changing stride must be one + if (stride.at(kStorageRank - 1) == 1) { + ptr_ = ptr; + for (int i = 0; i < kStorageRank - 1; ++i) { + stride_[i] = stride[i]; + } + } + else { + // Fastest-chaning stride must be 1. + reset(); + } + } + + /// Enables conversion from TensorRef of non-const type + CUTLASS_HOST_DEVICE + TensorRef( + TensorRef< + typename platform::remove_const::type, + kRank, + MapFunc, + kStorageRank, + Index, + LongIndex> const &ref + ): + ptr_(ref.data()) { + for (int i = 0; i < kStorageRank - 1; ++i) { + stride_[i] = ref.stride(i); + } + } + + /// Returns a reference to constant-valued tensor + CUTLASS_HOST_DEVICE + ConstTensorRef const_ref() const { + return ConstTensorRef(*this); + } + + /// Updates only the pointer + CUTLASS_HOST_DEVICE + void reset(Storage* ptr = nullptr) { + ptr_ = ptr; + } /// Updates the pointer, stride, and location within a TensorRef CUTLASS_HOST_DEVICE - void reset(Storage* ptr = nullptr, Coord stride = Coord(0)) { - ptr_ = ptr; - stride_ = stride; - } - - /// Conversion function - template - TensorRef convert() { - Coord converted_stride; - for (int i = 0; i < Rank - 1; ++i) { - converted_stride[i] = stride_[i] * Extent::kValue / Extent::kValue; + void reset(Storage* ptr, StorageCoord const & stride) { + // Fastest-changing stride must be one + if (stride.at(kStorageRank - 1) == 1) { + ptr_ = ptr; + for (int i = 0; i < kStorageRank - 1; ++i) { + stride_[i] = stride[i]; + } + } + else { + // Fastest-changing stride must be 1 - this is an error. 
+ reset(); } - converted_stride[Rank - 1] = stride_[Rank - 1]; - - return TensorRef(reinterpret_cast(ptr_), converted_stride); } /// Returns true if the TensorRef may be safely accessed CUTLASS_HOST_DEVICE - bool good() const { return ptr_ != nullptr; } + bool good() const { + return ptr_ != nullptr; + } /// Returns the pointer to referenced data CUTLASS_HOST_DEVICE - Storage* data() const { return ptr_; } + Storage * data() const { return ptr_; } /// Returns the stride of the tensor CUTLASS_HOST_DEVICE - Coord const& stride() const { return stride_; } + StorageCoord stride() const { + StorageCoord ld; + for (int i = 0; i < kStorageRank - 1; ++i) { + ld[i] = stride_[i]; + } + ld[kStorageRank - 1] = 1; + return ld; + } /// Returns the stride of the tensor in the given dimension CUTLASS_HOST_DEVICE - int const& stride(int dim) const { return stride_.at(dim); } + Index stride(int dim) const { + // fastest-changing stride assumbed to be 1 + if (dim + 1 >= kStorageRank) { + return 1; + } + return stride_.at(dim); + } /// Returns the maximum stride element as the 'leading dimension' CUTLASS_HOST_DEVICE - int leading_dim() const { return __NV_STD_MAX(stride_[1], stride_[2]); } + Index leading_dim(int idx = 0) const { return stride(idx); } + + /// Maps a logical coordinate to an n-D array in memory + CUTLASS_HOST_DEVICE + StorageCoord map(TensorCoord const &coord) const { + return coord_map_(coord); + } /// Computes the offset of an index from the origin of the tensor CUTLASS_HOST_DEVICE - long long offset(Coord const& coord) const { - return stride_.template dot(coord); + LongIndex offset(TensorCoord const& coord) const { + return stride().template dot(map(coord)); } /// Returns a reference to the element at a given Coord CUTLASS_HOST_DEVICE - Storage& at(Coord const& coord) const { return ptr_[offset(coord)]; } + Storage& at(TensorCoord const& coord) const { + return ptr_[offset(coord)]; + } - /// Element-wise accessor - Storage& operator[](Coord const& coord) const { 
return at(coord); } + /// Returns a reference to the element at a given linear index + CUTLASS_HOST_DEVICE + Storage& at(LongIndex idx) const { return ptr_[idx]; } /// Returns a reference to the element at a given Coord CUTLASS_HOST_DEVICE - Storage& at(int idx) const { return ptr_[idx]; } + Storage& operator[](TensorCoord const& coord) const { + return ptr_[offset(coord)]; + } - /// Element-wise accessor - Storage& operator[](int idx) const { return at(idx); } - - /// Adds an offset to the pointer + /// Returns a reference to the element at a given linear index CUTLASS_HOST_DEVICE - TensorRef& advance(Coord const& b) { - ptr_ += offset(b); + Storage& operator[](LongIndex idx) const { return ptr_[idx]; } + + /// Adds an offset to each pointer + CUTLASS_HOST_DEVICE + TensorRef & add_pointer_offset(LongIndex delta) { + ptr_ += delta; return *this; } /// Returns a TensorRef offset by a given amount CUTLASS_HOST_DEVICE - TensorRef operator+(Coord const& b) const { return TensorRef(ptr_ + offset(b), stride_); } + TensorRef operator+(TensorCoord const& b) const { + TensorRef result(*this); + result.add_pointer_offset(offset(b)); + return result; + } /// Returns a TensorRef offset by a given amount CUTLASS_HOST_DEVICE - TensorRef operator-(Coord const& b) const { return TensorRef(ptr_ - offset(b), stride_); } + TensorRef& operator+=(TensorCoord const& b) { + add_pointer_offset(offset(b)); + return *this; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorRef operator-(TensorCoord const& b) const { + TensorRef result(*this); + result.add_pointer_offset(-offset(b)); + return result; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorRef& operator-=(TensorCoord const& b) { + add_pointer_offset(-offset(b)); + return *this; + } }; -//////////////////////////////////////////////////////////////////////////////////////////////////// 
+/////////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partial specializations to handle degenerate cases. +// +/////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass +/// Specialization for rank=1 case with no internal StrideVector +template < + /// Data type of element stored within tensor + typename Storage_, + /// Rank of logical tensor + int Rank_, + /// Maps a Coord in the logical tensor index space to the internal n-D array + typename MapFunc_, + /// Index type used for coordinates + typename Index_, + /// Index type used for offsets and pointer differences + typename LongIndex_ +> +class TensorRef { + public: + /// Data type of individual access + typedef Storage_ Storage; + + /// Logical rank of tensor index space + static int const kRank = Rank_; + + /// Mapping function from logical coordinate to internal n-D array + typedef MapFunc_ MapFunc; + + /// Rank of internal storage + static int const kStorageRank = 1; + + /// Index type + typedef Index_ Index; + + /// Typically, strides in memory can be very large + typedef LongIndex_ LongIndex; + + /// Coordinate in logical tensor space + typedef Coord TensorCoord; + + /// Coordinate in storage n-D array + typedef Coord StorageCoord; + + /// Stride vector in storage coordinage space - assumes least significant stride + /// is 1 and does not store it. 
+ struct StrideVector { }; + + /// Tensor reference to of constant value + typedef TensorRef< + typename platform::remove_const::type const, + Rank_, + MapFunc_, + kStorageRank, + Index_, + LongIndex_> ConstTensorRef; + + // + // Definitions included for backwards compatibility - to be removed in next major release + // + + /// Coordinate in logical tensor space + typedef TensorCoord Coord_t; + + /// Logical rank of tensor index space + static int const Rank = kRank; + + private: + + /// Pointer + Storage* ptr_; + + /// Maps a logical coordinate to an n-D array's tensor space + MapFunc coord_map_; + + public: + + // + // Methods + // + + /// Helper for 1-D memory. All higher ranks are projected onto the fastest changing rank. + CUTLASS_HOST_DEVICE + TensorRef(Storage *ptr = nullptr): ptr_(ptr) { } + + /// Constructs from a single pointer and stride vector + CUTLASS_HOST_DEVICE + TensorRef(Storage* ptr, StrideVector const& stride) : ptr_(ptr) { + + } + + /// Constructs from a pointer and a stride vector of size kRank. If fastest changing + /// stride is not 1, construction fails and subsequent calls to good() will return false. + CUTLASS_HOST_DEVICE + TensorRef(Storage* ptr, StorageCoord const& stride) { + // Fastest-changing stride must be one + if (stride.at(kStorageRank - 1) == 1) { + ptr_ = ptr; + } + else { + // Fastest-chaning stride must be 1. 
+ reset(); + } + } + + /// Enables conversion from TensorRef of non-const type + CUTLASS_HOST_DEVICE + TensorRef( + TensorRef< + typename platform::remove_const::type, + kRank, + MapFunc, + kStorageRank, + Index, + LongIndex> const &ref + ): + ptr_(ref.data()) { + } + + /// Returns a reference to constant-valued tensor + CUTLASS_HOST_DEVICE + ConstTensorRef const_ref() const { + return ConstTensorRef(*this); + } + + /// Updates only the pointer + CUTLASS_HOST_DEVICE + void reset(Storage* ptr = nullptr) { + ptr_ = ptr; + } + + /// Updates the pointer, stride, and location within a TensorRef + CUTLASS_HOST_DEVICE + void reset(Storage* ptr, StorageCoord const & stride) { + // Fastest-changing stride must be one + if (stride.at(kStorageRank - 1) == 1) { + ptr_ = ptr; + } + else { + // Fastest-changing stride must be 1 - this is an error. + reset(); + } + } + + /// Returns true if the TensorRef may be safely accessed + CUTLASS_HOST_DEVICE + bool good() const { + return ptr_ != nullptr; + } + + /// Returns the pointer to referenced data + CUTLASS_HOST_DEVICE + Storage * data() const { return ptr_; } + + /// Returns the stride of the tensor + CUTLASS_HOST_DEVICE + StorageCoord stride() const { + StorageCoord ld; + ld[kStorageRank - 1] = 1; + return ld; + } + + /// Returns the stride of the tensor in the given dimension + CUTLASS_HOST_DEVICE + Index stride(int dim) const { + // fastest-changing stride assumbed to be 1 + return 1; + } + + /// Returns the maximum stride element as the 'leading dimension' + CUTLASS_HOST_DEVICE + Index leading_dim(int idx = 0) const { return 1; } + + /// Maps a logical coordinate to an n-D array in memory + CUTLASS_HOST_DEVICE + StorageCoord map(TensorCoord const &coord) const { + return coord_map_(coord); + } + + /// Computes the offset of an index from the origin of the tensor + CUTLASS_HOST_DEVICE + LongIndex offset(TensorCoord const& coord) const { + return stride().template dot(map(coord)); + } + + /// Returns a reference to the element 
at a given Coord + CUTLASS_HOST_DEVICE + Storage& at(TensorCoord const& coord) const { + return ptr_[offset(coord)]; + } + + /// Returns a reference to the element at a given linear index + CUTLASS_HOST_DEVICE + Storage& at(LongIndex idx) const { return ptr_[idx]; } + + /// Returns a reference to the element at a given Coord + CUTLASS_HOST_DEVICE + Storage& operator[](TensorCoord const& coord) const { + return ptr_[offset(coord)]; + } + + /// Returns a reference to the element at a given linear index + CUTLASS_HOST_DEVICE + Storage& operator[](LongIndex idx) const { return ptr_[idx]; } + + /// Adds an offset to each pointer + CUTLASS_HOST_DEVICE + TensorRef & add_pointer_offset(LongIndex delta) { + ptr_ += delta; + return *this; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorRef operator+(TensorCoord const& b) const { + TensorRef result(*this); + result.add_pointer_offset(offset(b)); + return result; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorRef& operator+=(TensorCoord const& b) { + add_pointer_offset(offset(b)); + return *this; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorRef operator-(TensorCoord const& b) const { + TensorRef result(*this); + result.add_pointer_offset(-offset(b)); + return result; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorRef& operator-=(TensorCoord const& b) { + add_pointer_offset(-offset(b)); + return *this; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/cutlass/tensor_ref_collection.h b/cutlass/tensor_ref_collection.h new file mode 100644 index 000000000..b2972e184 --- /dev/null +++ b/cutlass/tensor_ref_collection.h @@ -0,0 +1,420 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2018, NVIDIA 
CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Introduces TensorRefCollection concept and defines TensorRefBatch and TensorRefArray. 
+*/ + +#pragma once + +#include "cutlass/tensor_ref.h" + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// +// TensorRefCollection is a concept for storing a logical collection of TensorRef objects. Classes +// satisfying the TensorRefCollection concept must support the following: +// +// // Define storage type +// typedef typename TensorRefCollection::Storage Storage; +// +// // Define a type for offsets in memory +// typedef typename TensorRefCollection::LongIndex LongIndex; +// +// // Define a ConstIterator type satisfying TensorRefIterator +// typedef typename TensorRefCollection::ConstIterator TensorRefIterator; +// +// // Implement a begin() method. +// TensorRefIterator iterator = collection.begin(); +// +// +// TensorRefIterator is a concept for accessing an element in a TensorRefCollection. Classes +// satisfying the TensorRefIterator concept must support the following: +// +// // Define a TensorRef type accessed by the iterator +// typedef typename TensorRefIterator::TensorRef TensorRef; +// +// // Access the TensorRef +// TensorRef ref = *iterator; +// +// // Pre-increment and post-increment +// ++iterator; +// iterator++; +// +// // Pre-decrement and post-decrement +// --iterator; +// iterator--; +// +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// This satisfies TensorRefCollection and stores a collection of TensorRef objects that +/// have identical strides. TensorRef objects are separated by a linear stride. 
+template < + /// Data type of element stored within tensor + typename Storage_, + /// Rank of logical tensor + int Rank_, + /// Maps a Coord in the logical tensor index space to the internal n-D array + typename MapFunc_ = IdentityTensorMapFunc, + /// Rank of internal n-D array + int StorageRank_ = MapFunc_::kStorageRank, + /// Index type used for coordinates + typename Index_ = int, + /// Index type used for offsets and pointer differences + typename LongIndex_ = long long +> +struct TensorRefBatchStrided: + public TensorRef { + + // + // Type definitions + // + + /// Underlying TensorRef type + typedef TensorRef Base; + + /// Storage type + typedef typename Base::Storage Storage; + + /// Index type + typedef Index_ Index; + + /// Typically, strides in memory can be very large + typedef LongIndex_ LongIndex; + + /// Coordinate in logical tensor space + typedef Coord TensorCoord; + + /// Tensor reference implied by the TensorRefBatchStrided + typedef Base TensorRef; + + /// Constant iterator over tensors implied by TensorRefBatchStrided + class ConstIterator { + public: + /// TensorRef returned by the iterator + typedef Base TensorRef; + + private: + + /// Reference to the parent TensorBatchRef object + TensorRefBatchStrided const &ref_; + + /// Offset from the base TensorRef pointer + LongIndex offset_; + + public: + + /// Constructs a ConstIterator from a parent TensorRefBatchStrided + CUTLASS_HOST_DEVICE + ConstIterator( + TensorRefBatchStrided const &ref, + LongIndex offset = 0): ref_(ref), offset_(offset) { } + + /// Obtains a TensorRef pointed to by the iterator + CUTLASS_HOST_DEVICE + TensorRef *operator() const { + TensorRef ref(ref_); + ref.add_pointer_offset(offset_); + return ref; + } + + /// Advances the iterator to point to the next tensor + CUTLASS_HOST_DEVICE + ConstIterator &operator++() { + offset_ += ref_.tensor_stride; + return *this; + } + + /// Advances the iterator to point to the next tensor + CUTLASS_HOST_DEVICE + ConstIterator 
operator++(int) { + ConstIterator ret(*this); + offset_ += ref_.tensor_stride; + return ret; + } + + /// Returns an iterator advanced by (idx) amount + CUTLASS_HOST_DEVICE + ConstIterator operator+(Index idx) { + return ConstIterator(ref, offset_ + ref_.tensor_stride * idx); + } + + /// Advances this iterator by (idx) and returns a reference to self + CUTLASS_HOST_DEVICE + ConstIterator &operator+=(Index idx) { + offset_ += ref_.tensor_stride * idx; + return *this; + } + + /// Moves to the previous tensor + CUTLASS_HOST_DEVICE + ConstIterator &operator--() { + offset_ -= ref_.tensor_stride; + return *this; + } + + /// Moves to the previous tensor + CUTLASS_HOST_DEVICE + ConstIterator operator--(int) { + ConstIterator ret(*this); + offset_ -= ref_.tensor_stride; + return ret; + } + + /// Returns an iterator moved forward by (idx) amount + CUTLASS_HOST_DEVICE + ConstIterator operator-(Index idx) { + return ConstIterator(ref_, offset_ - ref_.tensor_stride * idx); + } + + /// Moves this iterator by (idx) and returns a reference to self + CUTLASS_HOST_DEVICE + ConstIterator &operator-=(Index idx) { + offset_ -= ref_.tensor_stride * idx; + return *this; + } + + /// Returns the difference in offset between two iterators + CUTLASS_HOST_DEVICE + Stride operator-(ConstIterator const &it) { + return offset_ - it.offset_; + } + }; + + // + // Data members + // + + /// Stride between tensors + LongIndex tensor_stride; + + // + // Methods + // + + // Default ctor + CUTLASS_HOST_DEVICE + TensorRefBatchStrided(): tensor_stride(0) { } + + // Constructs form a tensor reference and + CUTLASS_HOST_DEVICE + TensorRefBatchStrided(TensorRef const &ref, LongIndex _tensor_stride = 0): + TensorRef(ref), + tensor_stride(_tensor_stride) { } + + /// Gets the pointer offset + CUTLASS_HOST_DEVICE + LongIndex get_pointer_offset(Index idx) const { + return idx * tensor_stride; + } + + // Returns a reference + CUTLASS_HOST_DEVICE + TensorRef at(Index idx) const { + TensorRef ref(*this); + 
ref.add_pointer_offset(get_pointer_offset(idx)); + return ref; + } + + /// Returns an iterator + CUTLASS_HOST_DEVICE + ConstIterator begin() { + return ConstIterator(*this); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// This satisfies TensorRefCollection and stores a collection of TensorRef objects. This is a +/// structure of arrays in that the individual members of the TensorRef are held in distinct arrays. +/// +/// Note, TensorRef maps a logical coordinate space to an n-D array with rank kStorageRank. It +/// maintains a stride vector of similar rank, but the least significant rank is defined to be 1. +/// +/// The least significant stride of 1 is not stored, and therefore the number of stride arrays is +/// kStorageRank - 1. +template < + /// Data type of element stored within tensor + typename Storage_, + /// Rank of logical tensor + int Rank_, + /// Maps a Coord in the logical tensor index space to the internal n-D array + typename MapFunc_ = IdentityTensorMapFunc, + /// Rank of internal n-D array + int StorageRank_ = MapFunc_::kStorageRank, + /// Index type used for coordinates + typename Index_ = int, + /// Index type used for offsets and pointer differences + typename LongIndex_ = long long +> +struct TensorRefArray { + // + // Type definitions + // + + /// TensorRef type obtained from the TensorRefArray + typedef TensorRef TensorRef; + + /// Element pointed to by the TensorRef + typedef Storage_ Storage; + + /// Index type + typedef Index_ Index; + + /// Typically, strides in memory can be very large + typedef LongIndex_ LongIndex; + + /// Rank of the stride vector + static int const kStorageRank = TensorRef::kStorageRank; + + /// TensorRefIterator over TensorRef objects in TensorRefArray + class ConstIterator { + public: + + /// TensorRef returned by the iterator + typedef Base TensorRef; + + private: + /// Reference to the TensorRefArray + TensorRefArray const &ref_; + + /// Index 
into TensorRefArray + int idx_; + + public: + + /// Constructs a ConstIterator over the TensorRef objects + CUTLASS_HOST_DEVICE + ConstIterator(TensorArrayRef const &ref, int idx = 0): ref_(ref), idx_(idx) { } + + /// Obtains a TensorRef pointed to by this iterator + CUTLASS_HOST_DEVICE + TensorRef *operator() const { + return ref_.reference(idx_); + } + + /// Advances to next TensorRef + CUTLASS_HOST_DEVICE + ConstIterator &operator++() { + ++idx_; + return *this; + } + + /// Advances to next TensorRef + CUTLASS_HOST_DEVICE + ConstIterator operator++(int) { + ConstIterator ret(*this); + idx_ ++; + return ret; + } + + CUTLASS_HOST_DEVICE + ConstIterator operator+(Index idx) { + return ConstIterator(ref_, idx_ + idx); + } + + CUTLASS_HOST_DEVICE + ConstIterator &operator+=(Index idx) { + idx_ += idx; + return *this; + } + + CUTLASS_HOST_DEVICE + ConstIterator &operator--() { + --idx_; + return *this; + } + + /// Advances to next TensorRef + CUTLASS_HOST_DEVICE + ConstIterator operator--(int) { + ConstIterator ret(*this); + --idx_; + return ret; + } + + CUTLASS_HOST_DEVICE + ConstIterator &operator-=(Index idx) { + idx_ -= idx; + return *this; + } + + CUTLASS_HOST_DEVICE + ConstIterator operator-(Index idx) { + return ConstIterator(ref_, idx_ + idx); + } + }; + + // + // Data members + // + + /// Base addresses + Storage **pointers; + + /// Array of strides + Index *strides[kStorageRank - 1]; + + // + // Methods + // + + // Default ctor + CUTLASS_HOST_DEVICE + TensorArrayRef() { } + + // Construct from pointers to arrays to strides + CUTLASS_HOST_DEVICE + TensorArrayRef( + Storage **_pointers, + Index _strides[kStorageRank - 1]): pointers(_pointers) { + + // Copy pointers to strides arrays + for (int i = 0; i < kStorageRank - 1; ++i) { + strides[i] = _strides[i]; + } + } + + // Returns a TensorRef at the given index in the collection + CUTLASS_HOST_DEVICE + TensorRef at(Index idx) const { + Coord stride; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kStorageRank - 1; 
++i) { + stride[i] = stride_[idx][i]; + } + return TensorRef(pointers[idx], stride); + } + + /// Returns an TesnorRefIterator over the TensorRef objects in this collection + CUTLASS_HOST_DEVICE + ConstIterator begin() { + return ConstIterator(*this); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/cutlass/tensor_view.h b/cutlass/tensor_view.h index 89c6bd571..4ef99e027 100644 --- a/cutlass/tensor_view.h +++ b/cutlass/tensor_view.h @@ -24,51 +24,110 @@ **************************************************************************************************/ /*! \file \brief Defines a structure containing strides and a pointer to tensor data. + + TensorView is derived from TensorRef and contributes bounds to the tensor's index space. Thus, + it is a complete mathematical object and may be used in tensor algorithms. It is decoupled from + data storage and is therefore lightweight and may be embedded in larger tensor objects or + memory structures. + + See cutlass/tensor_ref.h for more details about the mapping of the logical tensor index space to + linear memory. 
*/ #pragma once #include -#include -#include +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref.h" namespace cutlass { //////////////////////////////////////////////////////////////////////////////////////////////////// -/// Host-side reference implementation of tensor operations -template -class TensorView : public TensorRef { +/// Defines a view into a logical tensor +template < + /// Data type of element stored within tensor + typename Storage_, + /// Rank of logical tensor + int Rank_ = 4, + /// Maps a Coord in the logical tensor index space to the internal n-D array + typename MapFunc_ = IdentityTensorMapFunc, + /// Rank of internal n-D array + int StorageRank_ = MapFunc_::kStorageRank, + /// Index type used for coordinates + typename Index_ = int, + /// Index type used for offsets and pointer differences + typename LongIndex_ = long long +> +class TensorView : public TensorRef { public: - /// Reference and stride - typedef TensorRef Base; + /// Base tensor reference + typedef TensorRef Base; - /// Reference and stride - typedef Base TensorRef_t; + /// Tensor reference to of constant value + typedef TensorRef< + typename platform::remove_const::type const, + Rank_, + MapFunc_, + StorageRank_, + Index_, + LongIndex_> ConstTensorRef; - /// Reference to constant type - typedef TensorRef ConstTensorRef_t; + /// Base tensor reference + typedef Base TensorRef; - /// Rank of tensor - static int const Rank = TensorRef_t::Rank; + /// Storage type + typedef typename Base::Storage Storage; + + /// Index type + typedef typename Base::Index Index; + + /// Coordinate in logical tensor space + typedef typename TensorRef::TensorCoord TensorCoord; + + /// Coordinate in storage n-D array + typedef typename TensorRef::StorageCoord StorageCoord; + + /// Stride vector in storage coordinate space + /// Least significant stride is = 1 and not stored + typedef typename TensorRef::StrideVector StrideVector; + + /// TensorView of constant value + typedef TensorView< + typename 
platform::remove_const::type const, + Rank_, + MapFunc_, + StorageRank_, + Index_, + LongIndex_> ConstTensorView; + + // + // Definitions included for backwards compatibility - to be removed in next major release + // + + /// Coordinate in logical tensor space + typedef TensorCoord Coord_t; + + /// Logical rank of tensor index space + static int const Rank = Base::kRank; /// Type used to compute the offset of an element to the base of a tensor - typedef int Offset_t; + typedef typename Base::LongIndex Offset_t; - /// Coordinate into tensor - typedef Coord Coord_t; + /// Base class + typedef TensorRef TensorRef_t; + + /// TensorRef to const-valued type + typedef typename TensorRef::ConstTensorRef ConstTensorRef_t; private: // // Data members // - /// Pointer to pitch-linear memory - TensorRef_t ref_; - /// Dimensions of coordinate (independent of stride) - Coord_t size_; + TensorCoord size_; public: // @@ -79,91 +138,126 @@ class TensorView : public TensorRef { CUTLASS_HOST_DEVICE TensorView() {} - /// Constructs a Tensor_view from a TensorRef and size + /// Constructs a TensorView from a TensorRef and size CUTLASS_HOST_DEVICE - TensorView(TensorRef_t const& _ref, Coord_t const& _size) : Base(_ref), size_(_size) {} + TensorView(Base const& _ref, TensorCoord const& _size) : Base(_ref), size_(_size) {} - /// Returns true if the Tensor_view is bound to some memory + /// Constructs a TensorView from a pointer, a stride vector, and size CUTLASS_HOST_DEVICE - bool good() const { return ref().good(); } + TensorView( + Storage *ptr, + StrideVector const &stride, + TensorCoord const& size + ): + Base(ptr, stride), size_(size) {} - /// Returns a pointer to data + /// Constructs a TensorView from a pointer, a stride vector, and size CUTLASS_HOST_DEVICE - T* data() const { return ref().data(); } + TensorView( + Storage *ptr, + StorageCoord const &stride, + TensorCoord const& size + ): + Base(ptr, stride), size_(size) {} /// Updates the reference and size of a Tensor_view object 
CUTLASS_HOST_DEVICE - void reset(TensorRef_t const& _ref = TensorRef_t(0), Coord_t const& _size = Coord_t()) { + void reset(Base const& _ref = Base(), TensorCoord const& _size = TensorCoord()) { Base::operator=(_ref); size_ = _size; } - /// Accesses the tensor reference pointing to data + /// Accesses the size CUTLASS_HOST_DEVICE - TensorRef_t& ref() { return *this; } - - /// - CUTLASS_HOST_DEVICE - ConstTensorRef_t const_ref() { return ConstTensorRef_t(data(), stride()); } - - /// Accesses the tensor reference pointing to data - CUTLASS_HOST_DEVICE - TensorRef_t const& ref() const { return *this; } + TensorCoord const& size() const { return size_; } /// Accesses the size CUTLASS_HOST_DEVICE - Coord_t const& size() const { return size_; } - - /// Accesses the size - CUTLASS_HOST_DEVICE - int size(int dim) const { return size_.at(dim); } - - /// Accesses the stride - CUTLASS_HOST_DEVICE - Coord_t const& stride() const { return ref().stride(); } - - /// Accesses the stride - CUTLASS_HOST_DEVICE - int const& stride(int dim) const { return ref().stride(dim); } + Index size(int dim) const { return size_.at(dim); } /// Assigns the Tensor_view CUTLASS_HOST_DEVICE TensorView& operator=(TensorView const& _tensor) { - Base::operator=(_tensor._ref); + Base::operator=(_tensor); size_ = _tensor.size_; return *this; } - /// Returns the index of an element - CUTLASS_HOST_DEVICE - Offset_t offset(Coord_t const& coord) const { return ref().offset(coord); } - /// Determines whether a location is within a tensor CUTLASS_HOST_DEVICE - bool contains(Coord_t const& coord) const { - for (int dim = 0; dim < Rank; ++dim) { - if (coord.at(dim) >= size_.at(dim)) { + bool contains(TensorCoord const& coord) const { + CUTLASS_PRAGMA_UNROLL + for (int dim = 0; dim < Rank_; ++dim) { + if (coord[dim] >= size_[dim]) { return false; } } return true; } - /// Element-wise accessor + /// Returns a TensorRef pointing to the first element of the tensor. 
CUTLASS_HOST_DEVICE - T& at(Coord_t const& coord) const { return ref().at(coord); } + TensorRef ref() const { + return TensorRef(*this); + } - /// Element-wise accessor - T& operator[](Coord const& coord) const { return at(coord); } - - /// Element-wise accessor + /// Returns a TensorRef pointing to the first element of the tensor. CUTLASS_HOST_DEVICE - T& at(Offset_t idx) const { return ref().at(idx); } + ConstTensorRef const_ref() const { + return ConstTensorRef(*this); + } /// Returns a Tensor_view given location and size quantities CUTLASS_HOST_DEVICE - TensorView subview(Coord_t const& location, Coord_t size) const { - return TensorView(ref() + location, size.clamp(size_ - location)); + TensorView subview(TensorCoord const& location, TensorCoord size) const { + return TensorView((*this) + location, size.clamp(size_ - location)); + } + + /// Returns the number of scalar elements needed to store tensor + CUTLASS_HOST_DEVICE + size_t capacity() const { + int max_rank = 0; + + StorageCoord mapped_size(this->map(size())); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Base::kStorageRank; ++i) { + if (!i || + this->stride(i) * mapped_size[i] > this->stride(max_rank) * mapped_size[max_rank]) { + max_rank = i; + } + } + return this->stride(max_rank) * mapped_size[max_rank]; + } + + /// Returns a TensorView offset by a given amount + CUTLASS_HOST_DEVICE + TensorView operator+(TensorCoord const& b) const { + TensorView result(*this); + result.add_pointer_offset(this->offset(b)); + return result; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorView& operator+=(TensorCoord const& b) { + this->add_pointer_offset(this->offset(b)); + return *this; + } + + /// Returns a TensorRef offset by a given amount + CUTLASS_HOST_DEVICE + TensorView operator-(TensorCoord const& b) const { + TensorRef result(*this); + result.add_pointer_offset(-this->offset(b)); + return result; + } + + /// Returns a TensorRef offset by a given amount + 
CUTLASS_HOST_DEVICE + TensorView& operator-=(TensorCoord const& b) { + this->add_pointer_offset(-this->offset(b)); + return *this; } }; diff --git a/cutlass/tile_allocation.h b/cutlass/tile_allocation.h new file mode 100644 index 000000000..81db797f9 --- /dev/null +++ b/cutlass/tile_allocation.h @@ -0,0 +1,143 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/*! \file + \brief Defines a fragment based on a Shape<> template. +*/ +#pragma once + +#include "cutlass/shape.h" +#include "cutlass/fragment.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/zip_tensor_ref.h" + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Class for storing a tile in memory and accessing it through a tensor ref +template +struct TileAllocation { + // + // Type definitions + // + + /// Scalar element + typedef Scalar_ Scalar; + + /// The actual storage (may differ from the scalar type) + typedef typename StorageType::Type Storage; + + /// Size of the allocation in units of scalars + typedef Shape_ Shape; + + /// Strides + typedef typename ShapeStrides::Shape Strides; + + /// Defines the tensor reference for this allocation + typedef TensorRef ConstTensorRef; + + /// Defines the tensor reference for this allocation + typedef TensorRef TensorRef; + + // + // Data members + // + + /// Storage + Storage storage[Shape::kD][Shape::kH][Shape::kW][Shape::kC]; + + // + // Methods + // + + /// Returns a pointer to the raw data + CUTLASS_DEVICE + Scalar *data() { return reinterpret_cast(&storage[0][0][0][0]); } + + /// Returns a const pointer to the raw data + CUTLASS_DEVICE + Scalar const *data() const { return reinterpret_cast(&storage[0][0][0][0]); } + + /// Returns a TensorRef object pointing to the data + CUTLASS_DEVICE + TensorRef reference() { + return TensorRef(data(), make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)); + } + + /// Returns a TensorRef object pointing to the data + CUTLASS_DEVICE + ConstTensorRef reference() const { + return ConstTensorRef(data(), make_Coord(Strides::kD, Strides::kH, Strides::kW, Strides::kC)); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// 
Manages a pair of tile allocations as if they are one allocation +template +struct ZipTileAllocation { + // + // Type definitions + // + + /// First tensor allocation + typedef First_ First; + + /// Second tensor allocation + typedef Second_ Second; + + /// Defines the tensor reference for this allocation + typedef ZipTensorRef TensorRef; + + /// Defines the tensor reference for this allocation + typedef ZipTensorRef + ConstTensorRef; + + // + // Data members + // + + /// First tensor allocation + First first; + + /// Second tensor allocation + Second second; + + // + // Methods + // + + /// Returns a TensorRef object pointing to the data + CUTLASS_DEVICE + TensorRef reference() { return TensorRef(first.reference(), second.reference()); } + + /// Returns a TensorRef object pointing to the data + CUTLASS_DEVICE + ConstTensorRef reference() const { return ConstTensorRef(first.reference(), second.reference()); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/cutlass/tile_coord.h b/cutlass/tile_coord.h new file mode 100644 index 000000000..b3d809bc3 --- /dev/null +++ b/cutlass/tile_coord.h @@ -0,0 +1,194 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. 
+ * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines a coordinate used for the CUTLASS 4-D tile structure. +*/ + +#pragma once + +#include "cutlass/coord.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// TileCoord wraps Coord<4, int> to provide a helper for accessing named dimensions. Classes +/// expecting a coordinate in the rank=4 index space of a CUTLASS tile structure should use TileCoord. 
+template +struct TileCoord : public Coord<4, Index_> { + + /// Index type + typedef Index_ Index; + + /// Underlying Coord<4> + typedef Coord<4, Index> Base; + + /// D dimension + static int kD = 0; + + /// H dimension + static int kH = 1; + + /// W dimension + static int kW = 2; + + /// C dimension + static int kC = 3; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + TileCoord() { } + + /// Constructs from Coord<3> and infers coord[kC] = 0 + CUTLASS_HOST_DEVICE + TileCoord(Coord<3, Index> const &coord): + Base(make_Coord(coord[0], coord[1], coord[2], 0)) { } + + /// Constructs from Coord<4> + CUTLASS_HOST_DEVICE + TileCoord(Coord<4, Index> const &coord): Base(coord) { } + + /// Constructs from an array of coordinate elements + CUTLASS_HOST_DEVICE + TileCoord(Index coord[4]): Base(coord) { } + + /// Helper to construct from a row and column + CUTLASS_HOST_DEVICE + TileCoord(Index d, Index h, Index w, Index c): Base(make_Coord(d, h, w, c)) { } + + /// Returns the D element of the coordinate + CUTLASS_HOST_DEVICE + Index const & d() const { return this->at(kD); } + + /// Returns the D element of the coordinate + CUTLASS_HOST_DEVICE + Index & d() { return this->at(kD); } + + /// Returns the H element of the coordinate + CUTLASS_HOST_DEVICE + Index const & h() const { return this->at(kH); } + + /// Returns the H element of the coordinate + CUTLASS_HOST_DEVICE + Index & h() { return this->at(kH); } + + /// Returns the W element of the coordinate + CUTLASS_HOST_DEVICE + Index const & w() const { return this->at(kW); } + + /// Returns the W element of the coordinate + CUTLASS_HOST_DEVICE + Index & w() { return this->at(kW); } + + /// Returns the Celement of the coordinate + CUTLASS_HOST_DEVICE + Index const & c() const { return this->at(kC); } + + /// Returns the C element of the coordinate + CUTLASS_HOST_DEVICE + Index & c() { return this->at(kC); } + + /// Gets H and W dimensions as a Coord<2> + CUTLASS_HOST_DEVICE + Coord<2> hw() const { + return 
make_Coord(h(), w()); + } + + /// Gets H, W, and C dimensions as a Coord<3> + CUTLASS_HOST_DEVICE + Coord<3> hwc() const { + return make_Coord(h(), w(), c()); + } + + /// Gets D, H, and W dimensions as a Coord<3> + CUTLASS_HOST_DEVICE + Coord<3> dhw() const { + return make_Coord(d(), h(), w()); + } + + // + // Coord operators + // + + /// Element-wise addition + CUTLASS_HOST_DEVICE + TileCoord operator+(Base const& b) const { + return TileCoord(Base::operator+(b)); + } + + /// Element-wise subtraction + CUTLASS_HOST_DEVICE + TileCoord operator-(Base const& b) const { + return TileCoord(Base::operator-(b)); + } + + /// Element-wise multiplication + CUTLASS_HOST_DEVICE + TileCoord operator*(Base const& b) const { + return TileCoord(Base::operator*(b)); + } + + /// Element-wise division + CUTLASS_HOST_DEVICE + TileCoord operator/(Base const& b) const { + return TileCoord(Base::operator/(b)); + } + + /// In-place addition + CUTLASS_HOST_DEVICE + TileCoord& operator+=(Base const& b) { + Base::operator+=(b); + return *this; + } + + /// In-place subtraction + CUTLASS_HOST_DEVICE + TileCoord& operator-=(Base const& b) { + Base::operator-=(b); + return *this; + } + + /// In-place multiplication + CUTLASS_HOST_DEVICE + TileCoord& operator*=(Base const& b) { + Base::operator*=(b); + return *this; + } + + /// In-place division + CUTLASS_HOST_DEVICE + TileCoord& operator/=(Base const& b) { + Base::operator/=(b); + return *this; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/cutlass/tile_iterator.h b/cutlass/tile_iterator.h index 5d39c4f80..51e577949 100644 --- a/cutlass/tile_iterator.h +++ b/cutlass/tile_iterator.h @@ -28,10 +28,13 @@ */ #pragma once -#include -#include -#include -#include +#include "cutlass/coord.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/fragment.h" +#include "cutlass/load_store.h" +#include "cutlass/predicate_vector.h" +#include 
"cutlass/vector.h" +#include namespace cutlass { @@ -61,12 +64,6 @@ as a Coord<4>. struct IteratorAdvance { enum Kind { kD, kH, kW }; }; - -/// Specifies whether iterator storage fragment consists of Scalar values or WMMA matrix -struct IteratorFragment { - enum Kind { kScalar, kWmmaMatrix }; -}; - /////////////////////////////////////////////////////////////////////////////////////////////////// /** @@ -77,7 +74,7 @@ template + int AccessSize> struct TileTraits { /// Shape of the tile typedef Tile_ Tile; @@ -89,11 +86,52 @@ struct TileTraits { typedef Iterations_ Iterations; /// Functor that returns the logical coordinate of each entity's initial offset in the tile + // + // ThreadOffset should be a functor defined like: + // + // struct ThreadOffsetExample { + // CUTLASS_DEVICE + // Coord<4> operator()() const { + // return make_Coord(0, threadIdx.y, threadIdx.x, 0); + // } + // }; + // typedef ThreadOffset_ ThreadOffset; + + /// Strides for immediate offset computation + typedef Shape<0, 0, 0, 0> ImmediateOffsetStrides; + + /// Access size + static int const kAccessSize = AccessSize; }; /////////////////////////////////////////////////////////////////////////////////////////////////// +/// Functor computing a predicate given the logical position of an access +template +struct RegularTilePredicateFunctor { + typedef Delta_ Delta; + + /// Dimensions of the bounding volume + Coord<3> bounds; + + /// Constructs a predicate functor given the bounds of a tensor + CUTLASS_HOST_DEVICE + RegularTilePredicateFunctor(Coord<3> _bounds) : bounds(_bounds) {} + + /// Computes the predicate given the logical position of an access + CUTLASS_HOST_DEVICE + bool operator()(Coord<3> iteration, Coord<3> offset) const { + return (iteration[0] * Delta::kD + offset[0] < bounds[0]) && + (iteration[1] * Delta::kH + offset[1] < bounds[1]) && + (iteration[2] * Delta::kW + offset[2] < bounds[2]); + } +}; + 
+/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct DumpType {}; /// Iterator for accessing a stripmined tile in memory template > struct TileIteratorBase { /// concept TileTraits @@ -117,7 +155,7 @@ struct TileIteratorBase { static IteratorAdvance::Kind const kAdvance = Advance_; /// Specifies iterator storage fragment type (Scalar or WmmaMatrix) - static IteratorFragment::Kind const kIteratorFragment = IteratorFragment_; + static FragmentElementType::Kind const kFragmentElementType = FragmentElementType_; /// Source or destination memory space static MemorySpace::Kind const kMemorySpace = MemorySpace; @@ -144,18 +182,19 @@ struct TileIteratorBase { typedef typename Traits::ThreadOffset ThreadOffset; /// The number of scalars accessed per load/store. - static int const kAccessSize = Tile::kC; + static int const kAccessSize = Traits::kAccessSize; /// The elements loaded/store by one instruction. typedef typename Vectorize::Type AccessType; /// The size of storage needed per fragment static int const kFragmentSize = - (kIteratorFragment == IteratorFragment::kWmmaMatrix ? 16 : sizeof(AccessType)); + (kFragmentElementType == FragmentElementType::kWmmaMatrix ? 16 : sizeof(AccessType)); /// The storage. typedef Fragment::kCount, kFragmentSize> Storage; /// The fragment. typedef Fragment::kCount * kAccessSize> Fragment; + /// The fragment iterator. typedef FragmentIterator FragmentIterator; /// The fragment const iterator. 
@@ -172,25 +211,61 @@ struct TileIteratorBase { /// Parameters to the iterator struct Params { - Index stride_d; + + // + // Dat members + // + + long long stride_d; Index stride_h; Index stride_w; - Index inc_d; + long long inc_d; Index inc_h; Index inc_w; - Index inc_advance; + long long inc_advance; + + // + // Methods + // + + /// Constructs params + CUTLASS_HOST_DEVICE + Params() : stride_d(0), stride_h(0), stride_w(0), inc_d(0), inc_h(0), inc_w(0) {} + + /// Constructs params + CUTLASS_HOST_DEVICE + Params(long long _stride_d, + Index _stride_h, + Index _stride_w, + long long _inc_d, + Index _inc_h, + Index _inc_w, + long long _inc_advance) + : stride_d(_stride_d), + stride_h(_stride_h), + stride_w(_stride_w), + inc_d(_inc_d), + inc_h(_inc_h), + inc_w(_inc_w), + inc_advance(_inc_advance) {} + + /// Constructs params with a stride vector + CUTLASS_HOST_DEVICE + Params(Coord<4> const &stride) { + initialize(stride); + } /// Initializes params CUTLASS_HOST_DEVICE - int initialize(Index _stride_d, + int initialize(long long _stride_d, Index _stride_h, Index _stride_w, - Index _inc_d, + long long _inc_d, Index _inc_h, Index _inc_w, - Index _inc_advance) { + long long _inc_advance) { stride_d = _stride_d; stride_h = _stride_h; stride_w = _stride_w; @@ -203,61 +278,79 @@ struct TileIteratorBase { return 0; } + /// Initializes the parameters object from a vector of strides CUTLASS_HOST_DEVICE - int initialize(Index _stride_d, Index _stride_h, Index _stride_w) { + int initialize(Coord<4> const &stride) { + return initialize(stride[0], stride[1], stride[2]); + } + + /// Initializes the parameters object from a vector of strides + CUTLASS_HOST_DEVICE + int initialize(long long _stride_d, Index _stride_h, Index _stride_w) { stride_d = _stride_d; stride_h = _stride_h; stride_w = _stride_w; inc_w = stride_w * Delta::kW; inc_h = stride_h * Delta::kH - stride_w * Delta::kW * (Iterations::kW - 1); + inc_d = stride_d * Delta::kD - stride_h * Delta::kH * (Iterations::kH - 1) - 
+ stride_w * Delta::kW * (Iterations::kW - 1); + + inc_advance = 0; if (kAdvance == IteratorAdvance::kH) { // Advance in the H dimension. - inc_d = 0; + inc_advance = Tile::kH * stride_h; } else if (kAdvance == IteratorAdvance::kW) { // Advance in the W dimension. - inc_d = stride_w * Tile::kW - stride_h * Tile::kH; + inc_advance = Tile::kW * stride_w; + } else { // Advance in the D dimension. - inc_d = stride_d; + inc_advance = Tile::kD * stride_d; } - inc_advance = 0; + inc_advance -= stride_d * Delta::kD * (Iterations::kD - 1) + + stride_h * Delta::kH * (Iterations::kH - 1) + + stride_w * Delta::kW * (Iterations::kW - 1); return 0; } + /// Gotta have this CUTLASS_HOST_DEVICE int initialize() { stride_d = 0; stride_h = 0; stride_w = 1; - inc_d = inc_h = inc_w = inc_advance = 0; + inc_advance = 0; + inc_d = inc_h = inc_w = 0; return 0; } }; /// Is the iterator valid? - CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const { return true; } + CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const { return true; } // // Static function members // /// Initializes a predicate vector - template - CUTLASS_DEVICE static void initialize_predicates(PredicateIterator predicate_it, - Coord<3> const &bounds, - Coord<3> const &offset = make_Coord(0, 0, 0)) { + template + CUTLASS_HOST_DEVICE static void initialize_predicates(PredicateIterator predicate_it, + PredicateFunctor const &predicate_func, + Coord<3> const &offset) { + CUTLASS_PRAGMA_UNROLL for (int d = 0; d < Iterations::kD; ++d) { - bool enable_d = (d * Delta::kD + offset[0] < bounds[0]); + CUTLASS_PRAGMA_UNROLL for (int h = 0; h < Iterations::kH; ++h) { - bool enable_h = (h * Delta::kH + offset[1] < bounds[1]); + CUTLASS_PRAGMA_UNROLL for (int w = 0; w < Iterations::kW; ++w) { - bool enable_w = (w * Tile::kC * Delta::kW + offset[2] < bounds[2]); - predicate_it.set(d, h, w, 0, enable_d && enable_h && enable_w); + bool enable = predicate_func(make_Coord(d, h, w), offset); + predicate_it.set(enable); + 
++predicate_it; } } } @@ -301,7 +394,7 @@ template > struct TileLoadIterator : public TileIteratorBase { /// Base class typedef TileIteratorBase Base; @@ -329,13 +422,13 @@ struct TileLoadIterator : public TileIteratorBase TensorRef; + /// Parameters struct Params : public BaseParams { /// Pointer to memory Scalar const *pointer; + // + // Methods + // + + /// Initialize params to access storage object + CUTLASS_HOST_DEVICE + Params() : pointer(0){ Base::Params::initialize(); } + + /// Initialize params to access storage object + CUTLASS_HOST_DEVICE + Params(Scalar const *ptr) : pointer(ptr) { Base::Params::initialize(); } + + /// Constructs with a CompactTensorRef<> + CUTLASS_HOST_DEVICE + Params(TensorRef const &ref): pointer(ref.data()) { + Base::Params::initialize(ref.stride()); + } + + /// Initialize params to access storage object + CUTLASS_HOST_DEVICE + Params(Scalar const *ptr, + long long _stride_d, + Index _stride_h, + Index _stride_w, + long long _inc_d, + Index _inc_h, + Index _inc_w, + Index _inc_advance) + : pointer(ptr) { + Base::Params::initialize( + _stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance); + } + + /// Initialize params to access storage object + CUTLASS_HOST_DEVICE + Params(Scalar const *ptr, long long stride_d, Index stride_h, Index stride_w) + : pointer(ptr) { + Base::Params::initialize(stride_d, stride_h, stride_w); + } + + /// Initializes params to access a raw pointer + CUTLASS_HOST_DEVICE + int initialize(TensorRef const &ref) { + pointer = ref.data(); + return Base::Params::initialize(ref.stride()); + } + /// Initialize params to access storage object CUTLASS_HOST_DEVICE int initialize(SharedStorage const &storage) { pointer = &storage[0]; + Base::Params::initialize(); + return 0; + } + + /// Initialize params to access storage object + CUTLASS_HOST_DEVICE + int initialize(Scalar const *ptr) { + pointer = ptr; + Base::Params::initialize(); return 0; } /// Initializes params to access a raw pointer 
CUTLASS_HOST_DEVICE - int initialize(Scalar const *ptr, Index stride_d, Index stride_h, Index stride_w) { + int initialize(Scalar const *ptr, long long stride_d, Index stride_h, Index stride_w) { Base::Params::initialize(stride_d, stride_h, stride_w); pointer = ptr; return 0; @@ -411,10 +566,10 @@ struct TileLoadIterator : public TileIteratorBase + /// Initializes a predicate vector using a RegularTilePredicateFunctor + template < + /// Predicate iterator + typename PredicateIterator> CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, Coord<3> const &bounds, Coord<3> const &block_offset = make_Coord(0, @@ -455,8 +612,23 @@ struct TileLoadIterator : public TileIteratorBase(bounds), + block_offset + make_Coord(thread_offset[0], thread_offset[1], thread_offset[2])); + } + + /// Initializes a predicate vector using an arbitrary predicate functor + template < + /// Predicate iterator + typename PredicateIterator, + /// Functor computing predicates + typename PredicateFunctor> + CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, + PredicateFunctor const &functor, + Coord<3> const &block_offset) { + Base::initialize_predicates( + predicate_it, + functor, + block_offset + make_Coord(thread_offset[0], thread_offset[1], thread_offset[2])); } // @@ -475,41 +647,27 @@ struct TileLoadIterator : public TileIteratorBase const &block_offset = make_Coord(0, 0, 0), ThreadOffset thread_offset_func = ThreadOffset()) : stage(0) { - int const offset = thread_offset_func()[2]; - params.pointer = &shared_storage[offset]; - } + params.pointer = ptr + thread_offset_func()[2]; - /// Returns the current pointer - CUTLASS_HOST_DEVICE - Scalar const *data() const { return params.pointer; } + params.stride_d = 0; + params.stride_h = 0; + params.stride_w = 1; - /// The accessor. 
- CUTLASS_DEVICE void get(AccessType &value, int d, int h, int w, int c) const { - int const imm = - ComputeOffsetFromStrides::get(d, h, w, c); - Load::load(value, params.pointer, imm); + params.inc_d = params.inc_h = params.inc_w = params.inc_advance = 0; } /// Increment in the D dimension @@ -524,8 +682,21 @@ struct TileLoadIterator : public TileIteratorBase::get(d, h, w, c); + Load::load(value, params.pointer, offset); + } + /// Increment the stage. - CUTLASS_DEVICE void inc_stage() { + CUTLASS_HOST_DEVICE void inc_stage() { if (Tile::kD > 1) { int const kStageSize = Tile::kH * Tile::kW * Tile::kC; if (stage == Tile::kD - 1) { @@ -538,7 +709,27 @@ struct TileLoadIterator : public TileIteratorBase const &offset) { + long long _offset = offset.template dot( + make_Coord(params.stride_d, params.stride_h, params.stride_w) + ); + + params.pointer += _offset; + return *this; + } + + /// Adds a raw offset to the pointer + CUTLASS_HOST_DEVICE void add_pointer_offset(Index offset) { params.pointer += offset; } + + CUTLASS_HOST_DEVICE Index stride_advance(void) { + Index stride = params.stride_h; + if (kAdvance == IteratorAdvance::kW) { + stride = params.stride_w; + } + return stride; + } + /// Loads a fragment and advances the iterator to the next tile. 
template CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it) { @@ -547,11 +738,12 @@ struct TileLoadIterator : public TileIteratorBase::load( - reinterpret_cast(frag_iterator.at(d, h, w, 0)), data(), 0); + for (int c = 0; c < Iterations::kC; ++c) { + if (*pred_it) { + load_element( + reinterpret_cast(frag_iterator.at(d, h, w, c)), d, h, w, c); + } } - if (w < Iterations::kW - 1) { inc_w(); } @@ -587,6 +779,19 @@ struct TileLoadIterator : public TileIteratorBase + CUTLASS_HOST_DEVICE void load(Fragment &fragment, int d) { + FragmentIterator frag_iterator(fragment); + for (int h = 0; h < Iterations::kH; ++h) { + for (int w = 0; w < Iterations::kW; ++w) { + for (int c = 0; c < Iterations::kC; ++c) { + load_element(reinterpret_cast(frag_iterator.at(0, h, w, c)), d, h, w, c); + } + } + } + } }; /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -626,7 +831,7 @@ template > struct TileStoreIterator : public TileIteratorBase { /// Base class typedef TileIteratorBase Base; @@ -660,11 +865,14 @@ struct TileStoreIterator : public TileIteratorBase TensorRef; + /// Parameters struct Params : public BaseParams { /// Pointer to memory Scalar *pointer; + // + // Methods + // + + // Default constructor + CUTLASS_HOST_DEVICE + Params() : pointer(0) {} + + // Default constructor + CUTLASS_HOST_DEVICE + Params(Scalar *ptr) : pointer(ptr) { Base::Params::initialize(); } + + /// Constructs with a CompactTensorRef<> + CUTLASS_HOST_DEVICE + Params(TensorRef const &ref): pointer(ref.data()) { + Base::Params::initialize(ref.stride()); + } + + // Default constructor + CUTLASS_HOST_DEVICE + Params(Scalar *ptr, long long stride_d, Index stride_h, Index stride_w) { + initialize(ptr, stride_d, stride_h, stride_w); + } + + // Default constructor + CUTLASS_HOST_DEVICE + Params(Scalar *ptr, + long long _stride_d, + Index _stride_h, + Index _stride_w, + long long _inc_d, + Index _inc_h, + Index _inc_w, + 
Index _inc_advance) { + initialize(ptr, _stride_d, _stride_h, _stride_w, _inc_d, _inc_h, _inc_w, _inc_advance); + } + /// Initialize params to access storage object CUTLASS_HOST_DEVICE int initialize(SharedStorage &storage) { pointer = &storage[0]; - return 0; + return Base::Params::initialize(); + } + + /// Initialize params to access storage object + CUTLASS_HOST_DEVICE + int initialize(Scalar *ptr) { + pointer = ptr; + return Base::Params::initialize(); } /// Initializes params to access a raw pointer CUTLASS_HOST_DEVICE - int initialize(Scalar *ptr, Index stride_d, Index stride_h, Index stride_w) { + int initialize(Scalar *ptr, long long stride_d, Index stride_h, Index stride_w) { Base::Params::initialize(stride_d, stride_h, stride_w); pointer = ptr; return 0; @@ -730,10 +988,10 @@ struct TileStoreIterator : public TileIteratorBase + /// Initializes a predicate vector using a RegularTilePredicateFunctor + template < + /// Predicate iterator + typename PredicateIterator> CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, Coord<3> const &bounds, Coord<3> const &block_offset = make_Coord(0, @@ -774,8 +1034,23 @@ struct TileStoreIterator : public TileIteratorBase(bounds), + block_offset + make_Coord(thread_offset[0], thread_offset[1], thread_offset[2])); + } + + /// Initializes a predicate vector using an arbitrary predicate functor + template < + /// Predicate iterator + typename PredicateIterator, + /// Functor computing predicates + typename PredicateFunctor> + CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, + PredicateFunctor const &functor, + Coord<3> const &block_offset) { + Base::initialize_predicates( + predicate_it, + functor, + block_offset + make_Coord(thread_offset[0], thread_offset[1], thread_offset[2])); } // @@ -794,25 +1069,22 @@ struct TileStoreIterator : public TileIteratorBase const &block_offset = make_Coord(0, 0, 0), - ThreadOffset thread_offset_func = ThreadOffset()) + 
TileStoreIterator(Params const &, Scalar *ptr, ThreadOffset thread_offset_func = ThreadOffset()) : stage(0) { - int const offset = thread_offset_func()[2]; - params.pointer = &shared_storage[offset]; - } + params.pointer = ptr + thread_offset_func()[2]; + params.stride_d = 0; + params.stride_h = 0; + params.stride_w = 1; - /// Returns the current pointer - CUTLASS_HOST_DEVICE - Scalar *data() const { return params.pointer; } + params.inc_d = params.inc_h = params.inc_w = params.inc_advance = 0; + } /// Increment in the D dimension CUTLASS_HOST_DEVICE void inc_d() { params.pointer += params.inc_d; } @@ -827,7 +1099,7 @@ struct TileStoreIterator : public TileIteratorBase 1) { int const kStageSize = Tile::kH * Tile::kW * Tile::kC; if (stage == Tile::kD - 1) { @@ -840,25 +1112,43 @@ struct TileStoreIterator : public TileIteratorBase::get(d, h, w, c); - Store::store(value, params.pointer, imm); + /// Adds a vector offset to the iterator + CUTLASS_HOST_DEVICE TileStoreIterator & operator+=(Coord<3> const &offset) { + params.pointer += offset.template dot( + make_Coord(params.stride_d, params.stride_h, params.stride_w) + ); + return *this; + } + + /// Adds a raw offset to the pointer + CUTLASS_HOST_DEVICE void add_pointer_offset(Index offset) { params.pointer += offset; } + + /// Stores a single fragment element into memory. + CUTLASS_HOST_DEVICE void store_element(AccessType const &value, int d, int h, int w, int c) { + int const offset = + ComputeOffsetFromStrides::get(d, h, w, c); + Store::store(value, params.pointer, offset); } - public: /// Stores a fragment and advances to the next tile. 
template - CUTLASS_HOST_DEVICE void store_post_increment(Fragment &fragment, PredicateIterator pred_it) { - FragmentIterator frag_iterator(fragment); + CUTLASS_HOST_DEVICE void store_post_increment(Fragment const &fragment, PredicateIterator pred_it) { + FragmentConstIterator frag_iterator(fragment); for (int d = 0; d < Iterations::kD; ++d) { for (int h = 0; h < Iterations::kH; ++h) { for (int w = 0; w < Iterations::kW; ++w, ++pred_it) { - if (*pred_it) { - Store::store( - reinterpret_cast(frag_iterator.at(d, h, w, 0)), data(), 0); + for (int c = 0; c < Iterations::kC; ++c) { + if (*pred_it) { + store_element( + reinterpret_cast(frag_iterator.at(d, h, w, c)), d, h, w, c); + } } if (w < Iterations::kW - 1) { inc_w(); @@ -877,23 +1167,103 @@ struct TileStoreIterator : public TileIteratorBase - CUTLASS_HOST_DEVICE void store_post_increment(Fragment &fragment) { + CUTLASS_HOST_DEVICE void store_post_increment(Fragment const &fragment) { typename PredicateVector::TrivialIterator pred_it; store_post_increment(fragment, pred_it); } /// Stores a fragment without advancing the iterator. template - CUTLASS_HOST_DEVICE void store(Fragment &fragment, PredicateIterator pred_it) const { + CUTLASS_HOST_DEVICE void store(Fragment const &fragment, PredicateIterator pred_it) const { TileStoreIterator _store_it(*this); _store_it.store_post_increment(fragment, pred_it); } /// Stores a fragment without advancing the iterator. template - CUTLASS_HOST_DEVICE void store(Fragment &fragment) const { + CUTLASS_HOST_DEVICE void store(Fragment const &fragment) const { typename PredicateVector::TrivialIterator pred_it; store(fragment, pred_it); } + + /// Loads a single fragment element from memory + CUTLASS_HOST_DEVICE void load_element(AccessType &value, int d, int h, int w, int c) const { + int const offset = + ComputeOffsetFromStrides::get(d, h, w, c); + + Load::load(value, params.pointer, offset); + } + + /// Loads a fragment and advances the iterator to the next tile. 
+ template + CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it) { + FragmentIterator frag_iterator(fragment); + + for (int d = 0; d < Iterations::kD; ++d) { + for (int h = 0; h < Iterations::kH; ++h) { + for (int w = 0; w < Iterations::kW; ++w, ++pred_it) { + for (int c = 0; c < Iterations::kC; ++c) { + if (*pred_it) { + load_element( + reinterpret_cast(frag_iterator.at(d, h, w, c)), d, h, w, c); + } + } + if (w < Iterations::kW - 1) { + inc_w(); + } + } + if (h < Iterations::kH - 1) { + inc_h(); + } + } + if (d < Iterations::kD - 1) { + inc_d(); + } + } + inc_advance(); + } + + /// Loads a fragment and advances the iterator to the next tile. + template + CUTLASS_HOST_DEVICE void load_post_increment(Fragment &fragment) { + typename PredicateVector::TrivialIterator pred_it; + load_post_increment(fragment, pred_it); + } + + /// Loads a fragment without advancing the iterator.. + template + CUTLASS_HOST_DEVICE void load(Fragment &fragment, PredicateIterator pred_it) const { + TileStoreIterator _load_it(*this); + _load_it.load_post_increment(fragment, pred_it); + } + + /// Loads a fragment without advancing the iterator.. + template + CUTLASS_HOST_DEVICE void load(Fragment &fragment) const { + typename PredicateVector::TrivialIterator pred_it; + load(fragment, pred_it); + } + + /// Loads a fragment without advancing the iterator.. 
+ template + CUTLASS_HOST_DEVICE void load(Fragment &fragment, int d) { + FragmentIterator frag_iterator(fragment); + for (int h = 0; h < Iterations::kH; ++h) { + for (int w = 0; w < Iterations::kW; ++w) { + for (int c = 0; c < Iterations::kC; ++c) { + load_element(reinterpret_cast(frag_iterator.at(0, h, w, c)), d, h, w, c); + } + } + } + } }; -} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/cutlass/tile_stream.h b/cutlass/tile_stream.h new file mode 100644 index 000000000..7790605a0 --- /dev/null +++ b/cutlass/tile_stream.h @@ -0,0 +1,378 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implements the tile stream concept, composing an iterator with a transformation. Offers + split-phase semantics, separating the initiation of an asynchronous memory operation with a + fence forcing it to complete. +*/ +#pragma once + +// clang-format off + +#include "cutlass/convert.h" +#include "cutlass/tile_iterator.h" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Generic stream for loading and transforming fragments +template > +struct TileLoadStream { + // + // Type definitions + // + + /// TileLoadIterator + typedef Iterator_ Iterator; + + /// Transformer + typedef Transformer_ Transformer; + + /// Fragment fetched from source memory + typedef typename Iterator::Fragment Fragment; + + /// Output fragment from transformer + typedef typename Transformer::OutputFragment TransformedFragment; + + /// Tensor reference expected by the stream + typedef typename Iterator::TensorRef TensorRef; + + /// Empty predicate vector struct + struct PredicateVector {}; + + /// Index type + typedef typename Iterator::Index Index; + + /// Parameters object used to construct generic load stream + struct Params { + /// Parameters to the iterator + typename 
Iterator::Params iterator; + + // + // Methods + // + + /// Default constructor + CUTLASS_HOST_DEVICE + Params() {} + + /// Constructor with iterator params + CUTLASS_HOST_DEVICE + Params(typename Iterator::Params const &_iterator) : iterator(_iterator) {} + }; + + // + // Data members + // + + /// Iterator to load tiles + Iterator iterator; + + /// Fragment loaded via iterator + Fragment fetched_fragment; + + /// Transformation applied to fragments + Transformer transformer; + + /// Transformed fragment from transformer + TransformedFragment transformed_fragment; + + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + TileLoadStream(Params const &_params, TensorRef const &_ref) + : iterator(_params.iterator, _ref) {} + + /// Ctor + CUTLASS_DEVICE + TileLoadStream(Params const &_params, + Coord<3> const &threadblock_offset = make_Coord(0, 0, 0) + ): iterator(_params.iterator, threadblock_offset) { } + + /// Loads a tile and increments the iterator + CUTLASS_DEVICE + void copy() { iterator.load_post_increment(fetched_fragment); } + + /// Commits the fetched fragment and applies a transformation + CUTLASS_DEVICE + void commit() { transformer.transform(fetched_fragment, transformed_fragment); } + + /// Accesses the loaded, transformed fragment + CUTLASS_DEVICE + Fragment &intermediate_fragment() { return fetched_fragment; } + + /// Accesses the loaded, transformed fragment + CUTLASS_DEVICE + TransformedFragment &fragment() { return transformed_fragment; } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Generic stream for transforming and storing fragments +template > +struct TileStoreStream { + // + // Type definitions + // + + /// TileLoadIterator + typedef Iterator_ Iterator; + + /// Transformer + typedef Transformer_ Transformer; + + /// Source fragment + typedef typename Transformer::InputFragment Fragment; + + /// Transformed fragment, compatible with Iterator::Fragment + typedef typename 
Transformer::OutputFragment TransformedFragment; + + /// Tensor reference expected by the underlying iterator + typedef typename Iterator::TensorRef TensorRef; + + /// Empty predicate vector struct + struct PredicateVector {}; + + /// Index type + typedef typename Iterator::Index Index; + + /// Parameters used to construct the stream + struct Params { + /// Parameters to the iterator + typename Iterator::Params iterator; + + // + // Methods + // + + /// Default constructor + CUTLASS_HOST_DEVICE + Params() {} + + /// Constructor with iterator params + CUTLASS_HOST_DEVICE + Params(typename Iterator::Params const &_iterator) : iterator(_iterator) {} + }; + + // + // Data members + // + + /// Iterator to store tiles + Iterator iterator; + + /// Transformation applied to inputs + Transformer transformer; + + /// Source fragment + Fragment source_fragment; + + /// Transformed fragment from transformer + TransformedFragment transformed_fragment; + + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + TileStoreStream(Params const &_params, TensorRef const &_ref) + : iterator(_params.iterator, _ref) {} + + /// Ctor + CUTLASS_DEVICE + TileStoreStream(Params const &_params, + Coord<3> const &threadblock_offset = make_Coord(0, 0, 0) + ): iterator(_params.iterator, threadblock_offset) { } + + /// Stores a fragment and increments the iterator + CUTLASS_DEVICE + void copy() { + + transformer.transform(source_fragment, transformed_fragment); + iterator.store_post_increment(transformed_fragment); + } + + /// Stores a fragment and increments the iterator + CUTLASS_DEVICE + void copy(Fragment const &frag) { + source_fragment = frag; + copy(); + } + + /// Commits the store operation + CUTLASS_DEVICE + void commit() {} + + /// Accesses the source fragment + CUTLASS_DEVICE + Fragment &fragment() { return source_fragment; } + + /// Accesses the fragment after transforming + CUTLASS_DEVICE + TransformedFragment &intermediate_fragment() { return transformed_fragment; } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Generic stream for loading and transforming fragments +template , + typename Transformer_ = Copy > +struct PredicatedTileLoadStream : public TileLoadStream { + // + // Type definitions + // + + typedef TileLoadStream Base; + + /// TileLoadIterator + typedef Iterator_ Iterator; + + /// Predicate functor + typedef PredicateFunctor_ PredicateFunctor; + + /// Transformer + typedef Transformer_ Transformer; + + /// Fragment fetched from source memory + typedef typename Base::Fragment Fragment; + + /// Output fragment from transformer + typedef typename Base::TransformedFragment TransformedFragment; + + /// Parameters object used to construct generic load stream + typedef typename Base::Params Params; + + // + // Data members + // + + /// Predicates + typename Iterator::PredicateVector predicates; + + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + PredicatedTileLoadStream(Params const &_params, + Coord<3> const &bounds, + Coord<3> const &threadblock_offset = make_Coord(0, 0, 0)) + : Base(_params, threadblock_offset) { + this->iterator.initialize_predicates( + predicates.begin(), PredicateFunctor(bounds), threadblock_offset); + } + + /// Loads a tile and increments the iterator + CUTLASS_DEVICE + void copy() { this->iterator.load_post_increment(this->fetched_fragment, predicates.begin()); } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Generic stream for transforming and storing fragments +template , + typename Transformer_ = Copy > +struct PredicatedTileStoreStream : public TileStoreStream { + // + // Type definitions + // + + typedef TileStoreStream Base; + + /// TileLoadIterator + typedef Iterator_ Iterator; + + /// Predicate functor + typedef PredicateFunctor_ PredicateFunctor; + + /// Transformer + typedef Transformer_ Transformer; + + /// Fragment fetched from source memory + typedef 
typename Base::Fragment Fragment; + + /// Output fragment from transformer + typedef typename Base::TransformedFragment TransformedFragment; + + /// Parameters object used to construct generic load stream + typedef typename Base::Params Params; + + // + // Data members + // + + /// Predicates + typename Iterator::PredicateVector predicates; + + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + PredicatedTileStoreStream(Params const &_params, + Coord<3> const &bounds, + Coord<3> const &threadblock_offset = make_Coord(0, 0, 0)) + : Base(_params, threadblock_offset) { + this->iterator.initialize_predicates( + predicates.begin(), PredicateFunctor(bounds), threadblock_offset); + } + + /// Stores the fragment and increments the iterator + CUTLASS_DEVICE + void copy() { + this->transformer.transform(this->source_fragment, this->transformed_fragment); + this->iterator.store_post_increment(this->transformed_fragment, predicates.begin()); + } + + /// Stores the fragment and increments the iterator + CUTLASS_DEVICE + void copy(Fragment const &frag) { + this->source_fragment = frag; + copy(); + } + + /// Commits the store operation + CUTLASS_DEVICE + void commit() {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +// clang-format on diff --git a/cutlass/tile_traits_standard.h b/cutlass/tile_traits_standard.h index 14ecd01ab..9145c5da9 100644 --- a/cutlass/tile_traits_standard.h +++ b/cutlass/tile_traits_standard.h @@ -28,7 +28,7 @@ */ #pragma once -#include +#include "cutlass/tile_iterator.h" namespace cutlass { @@ -204,6 +204,9 @@ struct TileTraitsStandard { /// Number of participating warps static int const kWarpCount = kThreads / kWarpSize; + /// By default, do not do scalar loads + static int const kAccessSize = 1; + // Static assertions static_assert(!(ShapeCount::kDhw % kThreads), "Tiling undefined if elements not divisible by threads."); @@ -223,8 +226,7 @@ struct 
TileTraitsStandard { typedef typename Traits::Delta Delta; /// Delta between each thread's access - /// TODO MTA this is wrong for sure, but Delta is used for stride computation at the moment - typedef Delta ImmediateOffsetStrides; + typedef Shape<0, 0, 0, 0> ImmediateOffsetStrides; /// Number of accesses typedef typename Traits::Iterations Iterations; diff --git a/cutlass/util/complex.h b/cutlass/util/complex.h new file mode 100644 index 000000000..260a3abd2 --- /dev/null +++ b/cutlass/util/complex.h @@ -0,0 +1,457 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include "cutlass/cutlass.h" +#include + +namespace cutlass { +namespace platform { + +////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Accessors for CUDA complex types +// + +/// Returns the real part of the complex number +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +CUTLASS_HOST_DEVICE +float const &real(cuFloatComplex const &z) { return z.x; } + +/// Returns the real part of the complex number +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +CUTLASS_HOST_DEVICE +float &real(cuFloatComplex &z) { return z.x; } + +/// Returns the real part of the complex number +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +CUTLASS_HOST_DEVICE +double const &real(cuDoubleComplex const &z) { return z.x; } + +/// Returns the real part of the complex number +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +CUTLASS_HOST_DEVICE +double &real(cuDoubleComplex &z) { return z.x; } + +/// Returns the imaginary part of the complex number +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate 
complex with a + // host-only type +CUTLASS_HOST_DEVICE +float const &imag(cuFloatComplex const &z) { return z.y; } + +/// Returns the imaginary part of the complex number +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +CUTLASS_HOST_DEVICE +float &imag(cuFloatComplex &z) { return z.y; } + +/// Returns the imaginary part of the complex number +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +CUTLASS_HOST_DEVICE +double const &imag(cuDoubleComplex const &z) { return z.y; } + +/// Returns the imaginary part of the complex number +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +CUTLASS_HOST_DEVICE +double &imag(cuDoubleComplex &z) { return z.y; } + +////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Class for representing and manipulating complex numbers with conversions from built-in CUDA +/// complex types. 
+template +class complex { + public: + /// Type alias for scalar type + typedef T value_type; + + private: + // + // Data members + // + + /// Real part + T _real; + + /// Imaginary part + T _imag; + + public: +// +// Methods +// + +/// Constructor +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type + CUTLASS_HOST_DEVICE + complex(T r = T(0), T i = T(0)) : _real(r), _imag(i) {} + +/// Conversion from cuFloatComplex +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type + CUTLASS_HOST_DEVICE + complex(cuFloatComplex const &z) : _real(platform::real(z)), _imag(platform::imag(z)) {} + +/// Conversion from cuDoubleComplex +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type + CUTLASS_HOST_DEVICE + complex(cuDoubleComplex const &z) : _real(platform::real(z)), _imag(platform::imag(z)) {} + +/// Accesses the real part of the complex number +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type + CUTLASS_HOST_DEVICE + T const &real() const { return _real; } + +/// Accesses the real part of the complex number +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type + CUTLASS_HOST_DEVICE + T &real() { return _real; } + +/// Accesses the imaginary part of the complex number +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type + CUTLASS_HOST_DEVICE + T const &imag() const { return _imag; } + +/// Accesses the imaginary part of the complex number +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type + CUTLASS_HOST_DEVICE + T &imag() { return _imag; } + +/// Converts to cuFloatComplex +#pragma hd_warning_disable // Suppresses warnings when 
attempting to instantiate complex with a + // host-only type + CUTLASS_HOST_DEVICE + operator cuFloatComplex() const { return make_cuFloatComplex(real(), imag()); } + +/// Converts to cuDoubleComplex +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type + CUTLASS_HOST_DEVICE + operator cuDoubleComplex() const { return make_cuDoubleComplex(real(), imag()); } +}; + +// +// Accessors for complex template +// + +/// Returns the real part of the complex number +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE T const &real(complex const &z) { + return z.real(); +} + +/// Returns the real part of the complex number +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE T &real(complex &z) { + return z.real(); +} + +/// Returns the imaginary part of the complex number +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE T const &imag(complex const &z) { + return z.imag(); +} + +/// Returns the imaginary part of the complex number +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE T &imag(complex &z) { + return z.imag(); +} + +// +// Output operators +// + +template +std::ostream &operator<<(std::ostream &out, complex const &z) { + T _r = real(z); + T _i = imag(z); + return out << _r << "+i" << _i; +} + +// +// Non-member operators defined for complex types +// + +/// Equality operator +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE bool operator==(complex const &lhs, complex const &rhs) { + return real(lhs) == real(rhs) && imag(lhs) == 
imag(rhs); +} + +/// Inequality operator +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE bool operator!=(complex const &lhs, complex const &rhs) { + return !(lhs == rhs); +} + +/// Addition +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE complex operator+(complex const &lhs, complex const &rhs) { + return complex(real(lhs) + real(rhs), imag(lhs) + imag(rhs)); +} + +/// Subtraction +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE complex operator-(complex const &lhs, complex const &rhs) { + return complex(real(lhs) - real(rhs), imag(lhs) - imag(rhs)); +} + +/// Multiplication +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE complex operator*(complex const &lhs, complex const &rhs) { + return complex(real(lhs) * real(rhs) - imag(lhs) * imag(rhs), + real(lhs) * imag(rhs) + imag(lhs) * real(rhs)); +} + +/// Scalar Multiplication +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE complex operator*(complex const &lhs, T const &s) { + return complex(real(lhs) * s, imag(lhs) * s); +} + +/// Scalar Multiplication +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE complex operator*(T const &s, complex const &rhs) { + return complex(s * real(rhs), s * imag(rhs)); +} + +/// Division +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE complex operator/(complex const &lhs, complex const &rhs) { + T d = 
(real(rhs) * real(rhs) + imag(rhs) * imag(rhs)); + + return complex((real(lhs) * real(rhs) + imag(lhs) * imag(rhs)) / d, + (imag(lhs) * real(rhs) - real(lhs) * imag(rhs)) / d); +} + +/// Scalar Division +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE complex operator/(complex const &lhs, T const &s) { + return complex(real(lhs) / s, imag(lhs) / s); +} + +/// Scalar divided by complex +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE complex operator/(T const &s, complex const &rhs) { + T d = (real(rhs) * real(rhs) + imag(rhs) * imag(rhs)); + + return complex((s * real(rhs)) / d, -(s * imag(rhs)) / d); +} + +/// Addition +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE complex &operator+=(complex &lhs, complex const &rhs) { + lhs = (lhs + rhs); + return lhs; +} + +/// Subtraction +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE complex &operator-=(complex &lhs, complex const &rhs) { + lhs = (lhs - rhs); + return lhs; +} + +/// Multiplication +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE complex &operator*=(complex &lhs, complex const &rhs) { + lhs = (lhs * rhs); + return lhs; +} + +/// Scalar multiplication +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE complex &operator*=(complex &lhs, T s) { + lhs = (lhs * s); + return lhs; +} + +/// Division +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE complex 
&operator/=(complex &lhs, complex const &rhs) { + lhs = (lhs / rhs); + return lhs; +} + +// +// Non-member functions defined for complex numbers +// + +/// Returns the magnitude of the complex number +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE T abs(complex const &z) { + return sqrt(norm(z)); +} + +/// Returns the magnitude of the complex number +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE T arg(complex const &z) { + return atan2(imag(z), real(z)); +} + +/// Returns the squared magnitude +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE T norm(complex const &z) { + return real(z) * real(z) + imag(z) * imag(z); +} + +/// Returns the complex conjugate +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE complex conj(complex const &z) { + return complex(real(z), -imag(z)); +} + +/// Projects the complex number z onto the Riemann sphere +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE complex proj(complex const &z) { + T d = real(z) * real(z) + imag(z) * imag(z) + T(1); + return complex((T(2) * real(z)) / d, (T(2) * imag(z)) / d); +} + +/// Returns a complex number with magnitude r and phase theta +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE complex polar(T const &r, T const &theta = T()) { + return complex(r * cos(theta), r * sin(theta)); +} + +/// Computes the complex exponential of z. 
+#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE complex exp(complex const &z) { + return complex(real(z) * cos(imag(z)), real(z) * sin(imag(z))); +} + +/// Computes the complex exponential of z. +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE complex log(complex const &z) { + return complex(log(abs(z)), arg(z)); +} + +/// Computes the complex exponential of z. +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE complex log10(complex const &z) { + return log(z) / T(log(T(10))); +} + +/// Computes the square root of complex number z +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE complex sqrt(complex const &z) { + return sqrt(T(2)) / T(2) * + complex(sqrt(sqrt(norm(z)) + real(z)), + (imag(z) < 0 ? T(-1) : T(1)) * sqrt(sqrt(norm(z)) - real(z))); +} + +/// Computes the cosine of complex z. +#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE complex cos(complex const &z) { + return (exp(z) + exp(-z)) / T(2); +} + +/// Computes the sin of complex z. 
+#pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex with a + // host-only type +template +CUTLASS_HOST_DEVICE complex sin(complex const &z) { + return (exp(-z) - exp(z)) * complex(T(0), T(1) / T(2)); +} + +////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace platform +} // namespace cutlass diff --git a/cutlass/util/cutlass_math.h b/cutlass/util/cutlass_math.h index 0ecdc4382..e3b46ef35 100644 --- a/cutlass/util/cutlass_math.h +++ b/cutlass/util/cutlass_math.h @@ -30,7 +30,7 @@ * \brief Math utilities */ -#include +#include "cutlass/util/platform.h" namespace cutlass { @@ -128,4 +128,38 @@ CUTLASS_HOST_DEVICE value_t lcm(value_t a, value_t b) { return temp ? (a / temp * b) : 0; } +/** + * log2 computation, what's the + * difference between the below codes and + * log2_up/down codes? + */ +template +CUTLASS_HOST_DEVICE value_t clz(value_t x) { + for (int i = 31; i >= 0; --i) { + if ((1 << i) & x) return 31 - i; + } + return 32; +} + +template +CUTLASS_HOST_DEVICE value_t find_log2(value_t x) { + int a = 31 - clz(x); + a += (x & (x - 1)) != 0; // Round up, add 1 if not a power of 2. + return a; +} + +/****************************************************************************** + * Min/Max + ******************************************************************************/ + +template +struct Min { + static int const kValue = (A < B) ? A : B; +}; + +template +struct Max { + static int const kValue = (A > B) ? 
A : B; +}; + } // namespace cutlass diff --git a/cutlass/gemm/identity_block_swizzle.h b/cutlass/util/numeric_types.h similarity index 79% rename from cutlass/gemm/identity_block_swizzle.h rename to cutlass/util/numeric_types.h index e1bdb2e00..d8094a256 100644 --- a/cutlass/gemm/identity_block_swizzle.h +++ b/cutlass/util/numeric_types.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, are permitted * provided that the following conditions are met: @@ -22,27 +22,26 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ -/*! \file - \brief Defies functors for mapping blockIdx to partitions of the GEMM computation. - - Currently, we only implement an identity mapping. +/*! + \file + \brief */ #pragma once namespace cutlass { -namespace gemm { -//////////////////////////////////////////////////////////////////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////////////////////////// -struct IdentityBlockSwizzle { - /// Ctor. - CUTLASS_DEVICE IdentityBlockSwizzle() {} +// +// Definitions for 1-bit binary and 4-bit integer types +// - /// Swizzle the block index. 
- CUTLASS_DEVICE dim3 swizzle() { return blockIdx; } -}; +struct bin1_t {}; // 1-bit binary type -//////////////////////////////////////////////////////////////////////////////////////////////////// +struct int4_t {}; // 4-bit signed integer type + +struct uint4_t {}; // 4-bit unsigned integer type + +/////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace gemm } // namespace cutlass diff --git a/cutlass/util/platform.h b/cutlass/util/platform.h index 2a44c10e6..3fd7c897d 100644 --- a/cutlass/util/platform.h +++ b/cutlass/util/platform.h @@ -110,9 +110,17 @@ #include // For integral constants, conditional metaprogramming, and type traits #endif -#include +#include "cutlass/cutlass.h" #endif + +//----------------------------------------------------------------------------- +// OS +//----------------------------------------------------------------------------- +#if defined(WIN32) || defined(_WIN32) || defined(__WIN32) && !defined(__CYGWIN__) +#define CUTLASS_OS_WINDOWS +#endif + /****************************************************************************** * Macros ******************************************************************************/ diff --git a/cutlass/vector.h b/cutlass/vector.h index a66dfdef7..aeababb66 100644 --- a/cutlass/vector.h +++ b/cutlass/vector.h @@ -31,7 +31,8 @@ #include #endif -#include +#include "cutlass/util/numeric_types.h" +#include "cutlass/util/platform.h" namespace cutlass { @@ -80,13 +81,43 @@ union Vector { uint32_t registers[kRegisters]; /// Accessor to the ith lane. - CUTLASS_DEVICE Scalar const& operator[](uint32_t i) const { return scalars[i]; } + CUTLASS_HOST_DEVICE Scalar const& operator[](uint32_t i) const { return scalars[i]; } /// Accessor to the ith lane. 
- CUTLASS_DEVICE Scalar& operator[](uint32_t i) { return scalars[i]; } + CUTLASS_HOST_DEVICE Scalar& operator[](uint32_t i) { return scalars[i]; } }; //////////////////////////////////////////////////////////////////////////////////////////////////// +template <> +union Vector { + /// The scalar type. + typedef half Scalar; + + /// The number of elements in the vector. + enum { kLanes = 1 }; + /// The size of the vector. + enum { kVectorSize = kLanes * (int)sizeof(Scalar) }; + /// The number of registers needed to store the vector. + enum { kRegisters = kVectorSize < 4 ? 1 : kVectorSize / 4 }; + + // Make sure that the vector type makes sense. + static_assert(kVectorSize <= 16, "Vector type is too large"); + + /// The aligned storage to make sure we have good alignment. + AlignedStruct aligned_; + /// The associated array of scalars. + uint16_t scalars[kLanes]; + + /// Accessor to the ith lane. + CUTLASS_HOST_DEVICE Scalar const& operator[](uint32_t i) const { + return reinterpret_cast(scalars[i]); + } + /// Accessor to the ith lane. + CUTLASS_HOST_DEVICE Scalar& operator[](uint32_t i) { + return reinterpret_cast(scalars[i]); + } +}; + #if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16) template @@ -112,19 +143,124 @@ union Vector { uint32_t registers[kRegisters]; /// Accessor to the ith lane. - CUTLASS_DEVICE Scalar const& operator[](uint32_t i) const { + CUTLASS_HOST_DEVICE Scalar const& operator[](uint32_t i) const { return reinterpret_cast(scalars[i]); } /// Accessor to the ith lane. - CUTLASS_DEVICE Scalar& operator[](uint32_t i) { return reinterpret_cast(scalars[i]); } + CUTLASS_HOST_DEVICE Scalar& operator[](uint32_t i) { + return reinterpret_cast(scalars[i]); + } }; #endif //////////////////////////////////////////////////////////////////////////////////////////////////// +/// Vector definition for 1-bit binary datatype +template +union Vector { + /// The scalar type. + typedef bin1_t Scalar; + + /// The number of elements in the vector. 
+ enum { kLanes = kLanes_ }; + /// The size of the vector. + enum { kVectorSize = kLanes / 8 }; + /// The number of registers needed to store the vector. + enum { kRegisters = kVectorSize < 4 ? 1 : kVectorSize / 4 }; + + static_assert((kLanes >= 8) && !(kLanes % 8), + "May only construct vectors of bin1_t that are multiples of 8 bits."); + + /// The aligned storage to make sure we have good alignment. + AlignedStruct aligned_; + /// The data in registers. + uint32_t registers[kRegisters]; + + /// Default Constructor + CUTLASS_HOST_DEVICE + Vector() {} + /// Constructor to convert from uint32_t type + CUTLASS_HOST_DEVICE Vector(uint32_t value) { registers[0] = value; } + /// Accessor to the ith lane. + CUTLASS_HOST_DEVICE bool operator[](uint32_t i) const { + return ( (registers[i / 32] & (1 << (i % 32))) != 0 ); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Vector definition for 4-bit signed integer datatype +template +union Vector { + /// The scalar type. + typedef int4_t Scalar; + + /// The number of elements in the vector. + enum { kLanes = kLanes_ }; + /// The size of the vector. + enum { kVectorSize = kLanes / 2 }; + /// The number of registers needed to store the vector. + enum { kRegisters = kVectorSize < 4 ? 1 : kVectorSize / 4 }; + + static_assert((kLanes >= 2) && !(kLanes % 2), + "May only construct vectors of int4_t that are multiples of 8 bits."); + + /// The aligned storage to make sure we have good alignment. + AlignedStruct aligned_; + /// The data in registers. + uint32_t registers[kRegisters]; + + /// Default Constructor + CUTLASS_HOST_DEVICE + Vector() {} + /// Constructor to convert from uint32_t type + CUTLASS_HOST_DEVICE Vector(uint32_t value) { registers[0] = value; } + /// Accessor to the ith lane. 
+ CUTLASS_HOST_DEVICE int operator[](uint32_t i) const { + return (registers[i / 8] >> (i % 8 * 4) & 0x0f) + - 16 * (registers[i / 8] >> (i % 8 * 4 + 3) & 0x01); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Vector definition for 4-bit unsigned integer datatype +template +union Vector { + /// The scalar type. + typedef uint4_t Scalar; + + /// The number of elements in the vector. + enum { kLanes = kLanes_ }; + /// The size of the vector. + enum { kVectorSize = kLanes / 2 }; + /// The number of registers needed to store the vector. + enum { kRegisters = kVectorSize < 4 ? 1 : kVectorSize / 4 }; + + static_assert((kLanes >= 2) && !(kLanes % 2), + "May only construct vectors of uint4_t that are multiples of 8 bits."); + + /// The aligned storage to make sure we have good alignment. + AlignedStruct aligned_; + /// The data in registers. + uint32_t registers[kRegisters]; + + /// Default Constructor + CUTLASS_HOST_DEVICE + Vector() {} + /// Constructor to convert from uint32_t type + CUTLASS_HOST_DEVICE Vector(uint32_t value) { registers[0] = value; } + /// Accessor to the ith lane. 
+ CUTLASS_HOST_DEVICE int operator[](uint32_t i) const { + return registers[i / 8] >> (i % 8 * 4) & 0x0f; + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + template -CUTLASS_DEVICE void make_zero(Scalar_& x) { +CUTLASS_HOST_DEVICE void make_zero(Scalar_& x) { x = Scalar_(0); } @@ -137,15 +273,29 @@ struct Vectorize { //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct Vectorize { - typedef Element_ Type; +template +struct Vectorize, kLanes_> { + typedef Vector Type; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Vectorize, kLanes_> { + typedef Vector Type; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Vectorize, kLanes_> { + typedef Vector Type; }; //////////////////////////////////////////////////////////////////////////////////////////////////// template -CUTLASS_DEVICE void make_zero(Vector& vec) { +CUTLASS_HOST_DEVICE void make_zero(Vector& vec) { for (int i = 0; i < Vector::kRegisters; ++i) { vec.registers[i] = 0; } diff --git a/cutlass/wmma_matrix.h b/cutlass/wmma_matrix.h index c4d8a0b54..61c4ed272 100644 --- a/cutlass/wmma_matrix.h +++ b/cutlass/wmma_matrix.h @@ -28,20 +28,23 @@ #pragma once #if defined(__CUDACC__) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 700) - -// Dependent header files should use the following macro to guard all code using -// nvcuda::wmma:: to enable compilation for CUDA Compute Capabilities < sm_70. -// Earlier shader models not support Tensor Cores. 
#define CUTLASS_USE_WMMA_API +#if defined(__CUDACC__) && (__CUDACC_VER_MAJOR__ >= 10) && (!defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 750) +#define CUTLASS_USE_SUBBYTE_WMMA +#endif + #include "stdio.h" +#if __CUDACC_VER_MAJOR__ >= 10 +#include +#else #include -#include -#include -#include -#include -#include +#endif +#include "cutlass/fragment.h" +#include "cutlass/matrix_traits.h" +#include "cutlass/shape.h" +#include "cutlass/vector.h" namespace cutlass { @@ -61,6 +64,34 @@ struct WmmaLayout { //////////////////////////////////////////////////////////////////////////////////////////////////// +/// Statically maps cutlass types to nvcuda::wmma datatypes +template +struct WmmaDataType{ + typedef Type_ Type; +}; + +#ifdef CUTLASS_USE_SUBBYTE_WMMA +/// Statically maps cutlass::Vector to nvcuda::wmma::experimental::precision::b1 +template<> +struct WmmaDataType > { + typedef nvcuda::wmma::experimental::precision::b1 Type; +}; + +/// Statically maps cutlass::Vector to nvcuda::wmma::experimental::precision::s4 +template<> +struct WmmaDataType > { + typedef nvcuda::wmma::experimental::precision::s4 Type; +}; + +/// Statically maps cutlass::Vector to nvcuda::wmma::experimental::precision::u4 +template<> +struct WmmaDataType > { + typedef nvcuda::wmma::experimental::precision::u4 Type; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + /// Adapter to nvcuda::wmma fragment load and store operations template WmmaShape_::kH, WmmaShape_::kD, /// The scalar. - Scalar_, + typename WmmaDataType::Type, /// The layout. typename WmmaLayout::Layout> { /// This type. @@ -117,7 +148,7 @@ struct WmmaMatrix WmmaShape_::kH, WmmaShape_::kD, /// The scalar. - Scalar_, + typename WmmaDataType::Type, /// The layout. typename WmmaLayout::Layout> { /// This type. 
@@ -188,6 +219,18 @@ struct WmmaMatrix //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass +// WmmaMatrix cannot be used in a Union and thus in cannot be used in our Vector implementation. +// The only use of WmmaMatrix in in combination with Vectorize has kLanes == 1. Due to this it is +// safe to keep the Vector->Scalar conversion for WmmaMatrix. +template +struct Vectorize, 1> { + typedef WmmaMatrix Type; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +} #endif // defined CUTLASS_USE_WMMA_API diff --git a/cutlass/zip_fragment.h b/cutlass/zip_fragment.h new file mode 100644 index 000000000..37a788614 --- /dev/null +++ b/cutlass/zip_fragment.h @@ -0,0 +1,150 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Models a pair of fragments +*/ +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/shape.h" +#include "cutlass/util/cutlass_math.h" +#include "cutlass/vector.h" + +namespace cutlass { + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/** +* @brief A template defining \ref fragment_concept +* @concept{fragment_concept} +*/ +template +struct ZipFragment { + /// First fragment object + typedef First_ First; + + /// Second fragment object + typedef Second_ Second; + + /// This class. + typedef ZipFragment This_; + + // + // Data members + // + + /// First fragment object + First first; + + /// Second fragment object + Second second; + + // + // Methods + // + + /// Default ctor + CUTLASS_DEVICE + ZipFragment() { } + + /// Copy ctor + CUTLASS_DEVICE + ZipFragment(First const &_first, Second const &_second): first(_first), second(_second) { } + + /// Clear a fragment. 
+ CUTLASS_DEVICE void clear() { + first.clear(); + second.clear(); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to construct a ZipFragment object +template +CUTLASS_HOST_DEVICE +ZipFragment make_ZipFragment(First const &first, Second const &second) { + return ZipFragment(first, second); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Zips two convert operations +template +struct ZipConvert { + /// First convert operator + typedef First_ First; + + /// Second convert operator + typedef Second_ Second; + + /// Defines the input zip fragment + typedef ZipFragment InputFragment; + + /// Defines the output zip fragment + typedef ZipFragment + OutputFragment; + + // + // + // + + /// First transformer + First first; + + /// Second transformer + Second second; + + // + // + // + + /// Ctor. + CUTLASS_DEVICE ZipConvert() {} + + /// Ctor. + CUTLASS_DEVICE ZipConvert(First const &_first, Second const &_second): first(_first), second(_second) { } + + /// Transform a fragment. 
+ CUTLASS_DEVICE void transform(InputFragment const& src, OutputFragment& dst) { + first.transform(src.first, dst.first); + second.transform(src.second, dst.second); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to construct a ZipConvert object +template +CUTLASS_HOST_DEVICE +ZipConvert make_ZipConvert(First const &first, Second const &second) { + return ZipConvert(first, second); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/cutlass/zip_tensor_ref.h b/cutlass/zip_tensor_ref.h new file mode 100644 index 000000000..d2cff9e0c --- /dev/null +++ b/cutlass/zip_tensor_ref.h @@ -0,0 +1,77 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines a structure containing a pair of TensorRef-like objects +*/ +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/tensor_ref.h" + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct ZipTensorRef { + /// First tensor ref + typedef First_ First; + + /// Second tensor ref + typedef Second_ Second; + + // + // Data members + // + + /// First TensorRef + First first; + + /// Second TensorRef + Second second; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + ZipTensorRef() {} + + CUTLASS_HOST_DEVICE + ZipTensorRef(First const& _first, Second const& _second) : first(_first), second(_second) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Constructs a ZipTensorRef +template +CUTLASS_HOST_DEVICE +ZipTensorRef make_ZipTensorRef(First const &first, Second const &second) { + return ZipTensorRef(first, second); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass diff --git a/cutlass/zip_tile_iterator.h b/cutlass/zip_tile_iterator.h new file mode 100644 index 000000000..f8ba4eee3 --- /dev/null +++ b/cutlass/zip_tile_iterator.h @@ -0,0 +1,287 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Constructs an iterator that owns two tile iterator instances +*/ + +#pragma once + +#include "cutlass/coord.h" +#include "cutlass/zip_tensor_ref.h" +#include "cutlass/zip_fragment.h" + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Constructs an iterator from a pair of iterators +template +class ZipTileIterator { + public: + /// First iterator type + typedef First_ First; + + /// Second iterator type + typedef Second_ Second; + + /// Params object + struct Params { + /// Parameters of first iterator + typename First::Params first; + + /// Parameters of second iterator + typename Second::Params second; + + /// Constructs a parameters object + CUTLASS_HOST_DEVICE + Params() {} + + /// Constructs a parameters object + CUTLASS_HOST_DEVICE + Params(typename First::Params const &_first, typename Second::Params const &_second) + : first(_first), second(_second) {} + }; + + /// Fragment type + typedef ZipFragment Fragment; + + /// Predicate vector + typedef typename First::PredicateVector PredicateVector; + + /// Index type + typedef typename First::Index Index; + + /// Tensor reference + typedef ZipTensorRef< + typename First::TensorRef, + typename Second::TensorRef> TensorRef; + + // + // Data members + // + + /// First iterator + First first; + + /// Second iterator + Second second; + + // + // Methods + // + + /// Default constructor + CUTLASS_DEVICE + ZipTileIterator() {} + + /// Constructs a zip iterator from params + CUTLASS_DEVICE + ZipTileIterator(Params const &_params, Coord<3> const &threadblock_offset = make_Coord(0, 0, 0)) + : first(_params.first, threadblock_offset), second(_params.second, threadblock_offset) {} + + /// Constructs a zip iterator from iterator instances + CUTLASS_DEVICE + ZipTileIterator(First const &_first, Second const &_second) : first(_first), second(_second) {} + + /// Constructs a zip iterator from iterator instances + CUTLASS_DEVICE + 
ZipTileIterator(TensorRef const &ref) : first(ref.first), second(ref.second) {} + + /// Constructs a zip iterator from iterator instances + CUTLASS_DEVICE + ZipTileIterator(Params const &_params, TensorRef const &ref): + first(_params.first, ref.first), second(_params.second, ref.second) {} + + // + // Predicate initialization + // + + /// Initializes a predicate vector using a RegularTilePredicateFunctor + template < + /// Predicate iterator + typename PredicateIterator> + CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, + Coord<3> const &bounds, + Coord<3> const &block_offset = make_Coord(0, + 0, + 0)) { + first.initialize_predicates(predicate_it, bounds, block_offset); + } + + /// Initializes a predicate vector using an arbitrary predicate functor + template < + /// Predicate iterator + typename PredicateIterator, + /// Functor computing predicates + typename PredicateFunctor> + CUTLASS_HOST_DEVICE void initialize_predicates(PredicateIterator predicate_it, + PredicateFunctor const &functor, + Coord<3> const &block_offset) { + first.initialize_predicates(predicate_it, functor, block_offset); + } + + // + // No predicates + // + + /// Loads a fragment and increments without predicates + template + CUTLASS_DEVICE void load_post_increment(Fragment &fragment) { + first.load_post_increment(fragment.first); + second.load_post_increment(fragment.second); + } + + /// Loads a fragment and increments without predicates + template + CUTLASS_DEVICE void load_post_increment(Fragment &fragment, + Coord<4> const &offset) { + first.load_post_increment(fragment.first, offset); + second.load_post_increment(fragment.second, offset); + } + + /// Loads a fragment without predicates + template + CUTLASS_DEVICE void load(Fragment &fragment) const { + first.load(fragment.first); + second.load(fragment.second); + } + + /// Loads a fragment without predicates + template + CUTLASS_DEVICE void load(Fragment &fragment, + Coord<4> const &offset) const { + 
first.load(fragment.first, offset); + second.load(fragment.second, offset); + } + + /// Stores a fragment and increments without predicates + template + CUTLASS_DEVICE void store_post_increment(Fragment const &fragment) { + first.store_post_increment(fragment.first); + second.store_post_increment(fragment.second); + } + + /// Stores a fragment and increments without predicates + template + CUTLASS_DEVICE void store_post_increment(Fragment const &fragment, + Coord<4> const &offset) { + first.store_post_increment(fragment.first, offset); + second.store_post_increment(fragment.second, offset); + } + + /// Stores a fragment without predicates + template + CUTLASS_DEVICE void store(Fragment const &fragment) const { + first.store(fragment.first); + second.store(fragment.second); + } + + /// Stores a fragment without predicates + template + CUTLASS_DEVICE void store(Fragment const &fragment, + Coord<4> const &offset) const { + first.store(fragment.first, offset); + second.store(fragment.second, offset); + } + + // + // With predication + // + + /// Loads a fragment and increments, using predicates + template + CUTLASS_DEVICE void load_post_increment(Fragment &fragment, PredicateIterator pred_it) { + first.load_post_increment(fragment.first, pred_it); + second.load_post_increment(fragment.second, pred_it); + } + + /// Loads a fragment with predicates + template + CUTLASS_DEVICE void load(Fragment &fragment, PredicateIterator pred_it) const { + first.load(fragment.first, pred_it); + second.load(fragment.second, pred_it); + } + + /// Loads a fragment and increments, using predicates + template + CUTLASS_DEVICE void store_post_increment(Fragment const &fragment, PredicateIterator pred_it) { + first.store_post_increment(fragment.first, pred_it); + second.store_post_increment(fragment.second, pred_it); + } + + /// Loads a fragment with predicates + template + CUTLASS_DEVICE void store(Fragment const &fragment, PredicateIterator pred_it) const { + first.store(fragment.first, 
pred_it); + second.store(fragment.second, pred_it); + } + + // + // Advances the iterators + // + + /// Increments store iterator to next tile + CUTLASS_DEVICE ZipTileIterator &increment(int count = 1) { + first.increment(count); + second.increment(count); + return *this; + } + + /// Increments to next tile + CUTLASS_DEVICE ZipTileIterator &operator++() { return increment(); } + + CUTLASS_DEVICE ZipTileIterator &operator+=(int count) { return increment(count); } + + /// Adds a vector offset to the underlying iterators + CUTLASS_DEVICE ZipTileIterator &operator+=(Coord<3> const &offset) { + first += offset; + second += offset; + return *this; + } + + /// Increments store iterator to previous tile + CUTLASS_DEVICE ZipTileIterator &decrement(int count = 1) { + first.decrement(count); + second.decrement(count); + return *this; + } + + /// Increments to subsequent tile + CUTLASS_DEVICE ZipTileIterator &operator--() { return decrement(); } + + /// Decrements to previous tile + CUTLASS_DEVICE ZipTileIterator &operator-=(int count) { return decrement(count); } + + /// Adds an offset to both iterators + CUTLASS_DEVICE void add_pointer_offset(Index offset) { + first.add_pointer_offset(offset); + second.add_pointer_offset(offset); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namspace cutlass diff --git a/docs/annotated.html b/docs/annotated.html index e6c405d59..da54a8ee0 100644 --- a/docs/annotated.html +++ b/docs/annotated.html @@ -74,303 +74,368 @@ $(function() {
Here are the classes, structs, unions and interfaces with brief descriptions:
[detail level 1234]
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
 Ncutlass
 Ngemm
 Nplatform
 CAlignedStruct
 CComputeOffsetFromShapeCompute the offset for the given coordinates in a cube
 CComputeOffsetFromShape< Shape< 1, kSh_, kSw_, 1 > >Compute the offset for the given coordinates in a cube with one channel and a depth of 1
 CComputeOffsetFromShape< Shape< 1, kSh_, kSw_, kSc_ > >Compute the offset for the given coordinates in a cube with a depth of 1
 CComputeOffsetFromStridesCompute the offset for the given coordinates in a cube
 CComputeOffsetFromStrides< Shape< 1, S_h_, S_w_, 1 > >Compute the offset for the given coordinates in a cube with one channel and a depth of 1
 CComputeOffsetFromStrides< Shape< 1, S_h_, S_w_, S_c_ > >Compute the offset for the given coordinates in a cube with a depth of 1
 CComputeThreadOffsetFromStridesDecompose threadId.x into coordinate of a cube whose dimensions are specified by Threads_. Afterwards compute the offset of those coordinates using Strides_
 CComputeThreadOffsetFromStrides< Shape< 1, T_h_, T_w_, 1 >, Shape< 1, S_h_, S_w_, 1 > >Specialization for D=1 and C=1
 CComputeThreadOffsetFromStrides< Shape< 1, T_h_, T_w_, T_c_ >, Shape< 1, S_h_, S_w_, S_c_ > >Specialization for D=1
 CConstPredicateTileAdapterAdapter to enable random access to predicates via logical coordinate within a tile
 CConvert
 CConvert< Fragment< InputScalar_, kScalars_ >, Fragment< OutputScalar_, kScalars_ > >
 CCoordStatically-sized array specifying Coords within a tensor
 CCopy
 Cdivide_assert
 CExtentReturns the extent of a scalar or vector
 CExtent< Vector< T, Lanes > >Returns the number of lanes of a vector if need be
 CExtent< Vector< T, Lanes > const >Returns the number of lanes of a vector if need be
 CFragmentA template defining Fragment Concept
 CFragmentConstIterator
 CFragmentIteratorA template defining Fragment Iterator Concept
 CFragmentLoad
 CFragmentLoad< IteratorFragment::kScalar, kAccessSize, Scalar_, Memory_, FragmentElement_, kStride >
 CFragmentLoad< IteratorFragment::kWmmaMatrix, kAccessSize, Scalar_, Memory_, FragmentElement_, kStride >
 CFragmentStore
 CFragmentStore< IteratorFragment::kScalar, kAccessSize, Scalar_, Memory_, FragmentElement_, kStride >
 CFragmentStore< IteratorFragment::kWmmaMatrix, kAccessSize, Scalar_, Memory_, FragmentElement_, kStride >
 CGemmOperandGemm operand - D = A * B + C
 CIdentityDescribes identity elements
 Cis_pow2
 CIteratorAdvanceSpecifies dimension in which post-increment accesses advance
 CIteratorFragmentSpecifies whether iterator storage fragment consists of Scalar values or WMMA matrix
 CLoad
 CLoad< double, 2, Memory_, true, 16 >
 CLoad< Scalar_, Lanes_, Memory_, true, 16 >
 CLoad< Scalar_, Lanes_, Memory_, true, 4 >
 CLoad< Scalar_, Lanes_, Memory_, true, 8 >
 Clog2_down
 Clog2_down< N, 1, Count >
 Clog2_up
 Clog2_up< N, 1, Count >
 CMatrixLayoutDescribes layouts of matrices
 CMemorySpaceEnum to specify which memory space data resides in
 CPredicateTileAdapterAdapter to enable random access to predicates via logical coordinate within a tile
 CPredicateVectorStatically sized array of bits implementing
 CReshapeTile
 CReshapeTile< Tile_, kAccessSize_, true >
 CShapeA Shape implementing Layout Concept describing the dimensions of a cube
 CShapeAdd
 CShapeCountCompute derived counted of a Layout Concept based class
 CShapeDiv
 CShapeMax
 CShapeMin
 CShapeMul
 CShapeScale
 CShapeStrides
 CShapeSub
 Csqrt_est
 CStorageType
 CStorageType< 1 >
 CStorageType< 2 >
 CStorageType< 4 >
 CStore
 CStore< double, 2, Memory_, true, 16 >
 CStore< Scalar_, Lanes_, Memory_, true, 16 >
 CStore< Scalar_, Lanes_, Memory_, true, 4 >
 CStore< Scalar_, Lanes_, Memory_, true, 8 >
 CTensorRefStructure modeling a pointer and stride into a tensor
 CTensorViewHost-side reference implementation of tensor operations
 CTiledThreadOffsetBasic thread offset function computed from a thread shape
 CTileIteratorBaseIterator for accessing a stripmined tile in memory
 CTileLoadIteratorAn iterator implementing Tile Load Iterator Concept for loading a tile from memory
 CTileStoreIteratorAn iterator implementing Tile Store Iterator Concept for storing a tile to memory
 CTileTraitsA template defining Tile Traits Concept
 CTileTraitsContiguousMajor
 CTileTraitsStandardChooses 'best' shape to enable warp raking along contiguous dimension if possible
 CTileTraitsStrideMajor
 CTileTraitsWarpRakeTiling in which warps rake across the contiguous dimension
 CTrivialPredicateTileAdapterAlways returns true predicate
 CVector
 CVector< half, kLanes_ >
 CVectorize
 CVectorize< Element_, 1 >
 CVectorTraitsTraits describing properties of vectors and scalar-as-vectors
 CVectorTraits< Vector< T, Lanes > >Partial specialization for actual cutlass::Vector
 CVectorTraits< Vector< T, Lanes > const >Partial specialization for actual cutlass::Vector
 Ncutlass
 CDebugType
 CDebugValue
diff --git a/docs/classcutlass_1_1PredicateVector_1_1ConstIterator-members.html b/docs/classcutlass_1_1PredicateVector_1_1ConstIterator-members.html index 860cd05cb..18f59fc0c 100644 --- a/docs/classcutlass_1_1PredicateVector_1_1ConstIterator-members.html +++ b/docs/classcutlass_1_1PredicateVector_1_1ConstIterator-members.html @@ -91,7 +91,7 @@ $(function() { diff --git a/docs/classcutlass_1_1PredicateVector_1_1ConstIterator.html b/docs/classcutlass_1_1PredicateVector_1_1ConstIterator.html index 1fbdc759c..7e7089a06 100644 --- a/docs/classcutlass_1_1PredicateVector_1_1ConstIterator.html +++ b/docs/classcutlass_1_1PredicateVector_1_1ConstIterator.html @@ -381,7 +381,7 @@ template<int kPredicates_, int kPredicatesPerByte_ = 4, int kPredicateStart_ diff --git a/docs/classcutlass_1_1PredicateVector_1_1Iterator-members.html b/docs/classcutlass_1_1PredicateVector_1_1Iterator-members.html index ca3ff04aa..73d0ebcaa 100644 --- a/docs/classcutlass_1_1PredicateVector_1_1Iterator-members.html +++ b/docs/classcutlass_1_1PredicateVector_1_1Iterator-members.html @@ -93,7 +93,7 @@ $(function() { diff --git a/docs/classcutlass_1_1PredicateVector_1_1Iterator.html b/docs/classcutlass_1_1PredicateVector_1_1Iterator.html index 42a069382..2cbc797d8 100644 --- a/docs/classcutlass_1_1PredicateVector_1_1Iterator.html +++ b/docs/classcutlass_1_1PredicateVector_1_1Iterator.html @@ -443,7 +443,7 @@ template<int kPredicates_, int kPredicatesPerByte_ = 4, int kPredicateStart_ diff --git a/docs/classcutlass_1_1TensorRef-members.html b/docs/classcutlass_1_1TensorRef-members.html index 4bf37ad13..202c9ab42 100644 --- a/docs/classcutlass_1_1TensorRef-members.html +++ b/docs/classcutlass_1_1TensorRef-members.html @@ -73,35 +73,52 @@ $(function() {
-
cutlass::TensorRef< Storage_, Rank_ > Member List
+
cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ > Member List
-

This is the complete list of members for cutlass::TensorRef< Storage_, Rank_ >, including all inherited members.

+

This is the complete list of members for cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >, including all inherited members.

- - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
advance(Coord< Rank > const &b)cutlass::TensorRef< Storage_, Rank_ >inline
at(Coord< Rank > const &coord) constcutlass::TensorRef< Storage_, Rank_ >inline
at(int idx) constcutlass::TensorRef< Storage_, Rank_ >inline
convert()cutlass::TensorRef< Storage_, Rank_ >inline
data() constcutlass::TensorRef< Storage_, Rank_ >inline
good() constcutlass::TensorRef< Storage_, Rank_ >inline
leading_dim() constcutlass::TensorRef< Storage_, Rank_ >inline
offset(Coord< Rank > const &coord) constcutlass::TensorRef< Storage_, Rank_ >inline
operator+(Coord< Rank > const &b) constcutlass::TensorRef< Storage_, Rank_ >inline
operator-(Coord< Rank > const &b) constcutlass::TensorRef< Storage_, Rank_ >inline
operator[](Coord< Rank > const &coord) constcutlass::TensorRef< Storage_, Rank_ >inline
operator[](int idx) constcutlass::TensorRef< Storage_, Rank_ >inline
Rankcutlass::TensorRef< Storage_, Rank_ >static
reset(Storage *ptr=nullptr, Coord< Rank > stride=Coord< Rank >(0))cutlass::TensorRef< Storage_, Rank_ >inline
Storage typedefcutlass::TensorRef< Storage_, Rank_ >
stride() constcutlass::TensorRef< Storage_, Rank_ >inline
stride(int dim) constcutlass::TensorRef< Storage_, Rank_ >inline
TensorRef()cutlass::TensorRef< Storage_, Rank_ >inline
TensorRef(Storage *ptr, Coord< Rank > stride)cutlass::TensorRef< Storage_, Rank_ >inline
add_pointer_offset(LongIndex delta)cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
at(TensorCoord const &coord) constcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
at(LongIndex idx) constcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
const_ref() constcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
ConstTensorRef typedefcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >
Coord_t typedefcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >
data() constcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
good() constcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
Index typedefcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >
kRankcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >static
kStorageRankcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >static
leading_dim(int idx=0) constcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
LongIndex typedefcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >
map(TensorCoord const &coord) constcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
MapFunc typedefcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >
offset(TensorCoord const &coord) constcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
operator+(TensorCoord const &b) constcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
operator+=(TensorCoord const &b)cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
operator-(TensorCoord const &b) constcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
operator-=(TensorCoord const &b)cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
operator[](TensorCoord const &coord) constcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
operator[](LongIndex idx) constcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
Rankcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >static
reset(Storage *ptr=nullptr)cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
reset(Storage *ptr, StorageCoord const &stride)cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
Storage typedefcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >
StorageCoord typedefcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >
stride() constcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
stride(int dim) constcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
StrideVector typedefcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >
TensorCoord typedefcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >
TensorRef(Storage *ptr=nullptr)cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
TensorRef(Storage *ptr, Index ldm)cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
TensorRef(Storage *ptr, StrideVector const &stride)cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
TensorRef(Storage *ptr, StorageCoord const &stride)cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
TensorRef(TensorRef< typename platform::remove_const< Storage >::type, kRank, MapFunc, kStorageRank, Index, LongIndex > const &ref)cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
diff --git a/docs/classcutlass_1_1TensorRef.html b/docs/classcutlass_1_1TensorRef.html index 05a9b3dd5..1053ca0a9 100644 --- a/docs/classcutlass_1_1TensorRef.html +++ b/docs/classcutlass_1_1TensorRef.html @@ -5,7 +5,7 @@ -Cutlass: cutlass::TensorRef< Storage_, Rank_ > Class Template Reference +Cutlass: cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ > Class Template Reference @@ -78,93 +78,278 @@ $(function() { Static Public Attributes | List of all members
-
cutlass::TensorRef< Storage_, Rank_ > Class Template Reference
+
cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ > Class Template Reference
-

Structure modeling a pointer and stride into a tensor. -

-

#include <tensor_ref.h>

+
+Inheritance diagram for cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >:
+
+
+ + +cutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ > +cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ > + +
- - - + + + + + + + + + + + + + + + + + + + + + + + + + +

Public Types

typedef Storage_ Storage
 Data type of individual access. More...
 
typedef Storage_ Storage
 Data type of individual access. More...
 
typedef MapFunc_ MapFunc
 Mapping function from logical coordinate to internal n-D array. More...
 
typedef Index_ Index
 Index type. More...
 
typedef LongIndex_ LongIndex
 Typically, strides in memory can be very large. More...
 
typedef Coord< kRankTensorCoord
 Coordinate in logical tensor space. More...
 
typedef Coord< kStorageRankStorageCoord
 Coordinate in storage n-D array. More...
 
typedef Coord< kStorageRank - 1 > StrideVector
 
typedef TensorRef< typename platform::remove_const< Storage >::type const, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ > ConstTensorRef
 Tensor reference to of constant value. More...
 
typedef TensorCoord Coord_t
 Coordinate in logical tensor space. More...
 
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

Public Member Functions

CUTLASS_HOST_DEVICE TensorRef ()
 Default ctor. More...
 
CUTLASS_HOST_DEVICE TensorRef (Storage *ptr, Coord< Rank > stride)
 Constructs from a pointer, size, and stride. More...
 
CUTLASS_HOST_DEVICE void reset (Storage *ptr=nullptr, Coord< Rank > stride=Coord< Rank >(0))
 Updates the pointer, stride, and location within a TensorRef. More...
 
template<typename T >
TensorRef< T, Rankconvert ()
 Conversion function. More...
 
CUTLASS_HOST_DEVICE bool good () const
 Returns true if the TensorRef may be safely accessed. More...
 
CUTLASS_HOST_DEVICE Storagedata () const
 Returns the pointer to referenced data. More...
 
CUTLASS_HOST_DEVICE Coord< Rank > const & stride () const
 Returns the stride of the tensor. More...
 
CUTLASS_HOST_DEVICE int const & stride (int dim) const
 Returns the stride of the tensor in the given dimension. More...
 
CUTLASS_HOST_DEVICE int leading_dim () const
 Returns the maximum stride element as the 'leading dimension'. More...
 
CUTLASS_HOST_DEVICE long long offset (Coord< Rank > const &coord) const
 Computes the offset of an index from the origin of the tensor. More...
 
CUTLASS_HOST_DEVICE Storageat (Coord< Rank > const &coord) const
 Returns a reference to the element at a given Coord. More...
 
Storageoperator[] (Coord< Rank > const &coord) const
 Element-wise accessor. More...
 
CUTLASS_HOST_DEVICE Storageat (int idx) const
 Returns a reference to the element at a given Coord. More...
 
Storageoperator[] (int idx) const
 Element-wise accessor. More...
 
CUTLASS_HOST_DEVICE TensorRefadvance (Coord< Rank > const &b)
 Adds an offset to the pointer. More...
 
CUTLASS_HOST_DEVICE TensorRef operator+ (Coord< Rank > const &b) const
 Returns a TensorRef offset by a given amount. More...
 
CUTLASS_HOST_DEVICE TensorRef operator- (Coord< Rank > const &b) const
 Returns a TensorRef offset by a given amount. More...
 
CUTLASS_HOST_DEVICE TensorRef (Storage *ptr=nullptr)
 Helper for 1-D memory. All higher ranks are projected onto the fastest changing rank. More...
 
CUTLASS_HOST_DEVICE TensorRef (Storage *ptr, Index ldm)
 Helper to construct from a pointer and single stride element for 2-D pitch linear memory. More...
 
CUTLASS_HOST_DEVICE TensorRef (Storage *ptr, StrideVector const &stride)
 Constructs from a single pointer and stride vector. More...
 
CUTLASS_HOST_DEVICE TensorRef (Storage *ptr, StorageCoord const &stride)
 
CUTLASS_HOST_DEVICE TensorRef (TensorRef< typename platform::remove_const< Storage >::type, kRank, MapFunc, kStorageRank, Index, LongIndex > const &ref)
 Enables conversion from TensorRef of non-const type. More...
 
CUTLASS_HOST_DEVICE ConstTensorRef const_ref () const
 Returns a reference to constant-valued tensor. More...
 
CUTLASS_HOST_DEVICE void reset (Storage *ptr=nullptr)
 Updates only the pointer. More...
 
CUTLASS_HOST_DEVICE void reset (Storage *ptr, StorageCoord const &stride)
 Updates the pointer, stride, and location within a TensorRef. More...
 
CUTLASS_HOST_DEVICE bool good () const
 Returns true if the TensorRef may be safely accessed. More...
 
CUTLASS_HOST_DEVICE Storagedata () const
 Returns the pointer to referenced data. More...
 
CUTLASS_HOST_DEVICE StorageCoord stride () const
 Returns the stride of the tensor. More...
 
CUTLASS_HOST_DEVICE Index stride (int dim) const
 Returns the stride of the tensor in the given dimension. More...
 
CUTLASS_HOST_DEVICE Index leading_dim (int idx=0) const
 Returns the maximum stride element as the 'leading dimension'. More...
 
CUTLASS_HOST_DEVICE StorageCoord map (TensorCoord const &coord) const
 Maps a logical coordinate to an n-D array in memory. More...
 
CUTLASS_HOST_DEVICE LongIndex offset (TensorCoord const &coord) const
 Computes the offset of an index from the origin of the tensor. More...
 
CUTLASS_HOST_DEVICE Storageat (TensorCoord const &coord) const
 Returns a reference to the element at a given Coord. More...
 
CUTLASS_HOST_DEVICE Storageat (LongIndex idx) const
 Returns a reference to the element at a given linear index. More...
 
CUTLASS_HOST_DEVICE Storageoperator[] (TensorCoord const &coord) const
 Returns a reference to the element at a given Coord. More...
 
CUTLASS_HOST_DEVICE Storageoperator[] (LongIndex idx) const
 Returns a reference to the element at a given linear index. More...
 
CUTLASS_HOST_DEVICE TensorRefadd_pointer_offset (LongIndex delta)
 Adds an offset to each pointer. More...
 
CUTLASS_HOST_DEVICE TensorRef operator+ (TensorCoord const &b) const
 Returns a TensorRef offset by a given amount. More...
 
CUTLASS_HOST_DEVICE TensorRefoperator+= (TensorCoord const &b)
 Returns a TensorRef offset by a given amount. More...
 
CUTLASS_HOST_DEVICE TensorRef operator- (TensorCoord const &b) const
 Returns a TensorRef offset by a given amount. More...
 
CUTLASS_HOST_DEVICE TensorRefoperator-= (TensorCoord const &b)
 Returns a TensorRef offset by a given amount. More...
 
- - - + + + + + + + + +

Static Public Attributes

static int const Rank = Rank_
 Rank of tensor. More...
 
static int const kRank = Rank_
 Logical rank of tensor index space. More...
 
static int const kStorageRank = StorageRank_
 Rank of internal storage. More...
 
static int const Rank = kRank
 Logical rank of tensor index space. More...
 

Member Typedef Documentation

- -

◆ Storage

+ +

◆ ConstTensorRef

-template<typename Storage_, int Rank_>
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
- + + +
typedef Storage_ cutlass::TensorRef< Storage_, Rank_ >::Storagetypedef TensorRef< typename platform::remove_const<Storage>::type const, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_> cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstTensorRef
+
+ +
+
+ +

◆ Coord_t

+ +
+
+
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + +
typedef TensorCoord cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::Coord_t
+
+

Require at least rank=1. Mathematically, a rank=0 tensor would be considered to be a scalar, but degenerate cases such as these are difficult to accommodate without extensive C++ metaprogramming or support for zero-length arrays.

+ +
+
+ +

◆ Index

+ +
+
+
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + +
typedef Index_ cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::Index
+
+ +
+
+ +

◆ LongIndex

+ +
+
+
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + +
typedef LongIndex_ cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::LongIndex
+
+ +
+
+ +

◆ MapFunc

+ +
+
+
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + +
typedef MapFunc_ cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::MapFunc
+
+ +
+
+ +

◆ Storage

+ +
+
+
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + +
typedef Storage_ cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::Storage
+
+ +
+
+ +

◆ StorageCoord

+ +
+
+
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + +
typedef Coord<kStorageRank> cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::StorageCoord
+
+ +
+
+ +

◆ StrideVector

+ +
+
+
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + +
typedef Coord<kStorageRank - 1> cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::StrideVector
+
+

Stride vector in storage coordinage space - assumes least significant stride is 1 and does not store it.

+ +
+
+ +

◆ TensorCoord

+ +
+
+
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + +
typedef Coord<kRank> cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::TensorCoord
@@ -172,21 +357,22 @@ template<typename Storage_, int Rank_>

Constructor & Destructor Documentation

- -

◆ TensorRef() [1/2]

+ +

◆ TensorRef() [1/5]

-template<typename Storage_, int Rank_>
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
diff --git a/docs/hgemm__global__tile_8h.html b/docs/hgemm__global__tile_8h.html index b62b8c143..0b2e247f5 100644 --- a/docs/hgemm__global__tile_8h.html +++ b/docs/hgemm__global__tile_8h.html @@ -82,10 +82,10 @@ $(function() {

Tile traits used to construct global tile iterator for HGEMM. This is intended to partition the thread block-level tile into 2D subtiles loaded by the threads and facilitate memory accesses larger than 16 bits. More...

-
- + - + +
CUTLASS_HOST_DEVICE cutlass::TensorRef< Storage_, Rank_ >::TensorRef CUTLASS_HOST_DEVICE cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::TensorRef ()Storageptr = nullptr)
@@ -199,27 +385,65 @@ template<typename Storage_, int Rank_> - -

◆ TensorRef() [2/2]

+ +

◆ TensorRef() [2/5]

-template<typename Storage_, int Rank_>
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + +
- + - + - + + + + + + + + +
CUTLASS_HOST_DEVICE cutlass::TensorRef< Storage_, Rank_ >::TensorRef CUTLASS_HOST_DEVICE cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::TensorRef (StorageStorage ptr,
Coord< RankIndex ldm 
)
+
+inline
+
+ +
+ + +

◆ TensorRef() [3/5]

+ +
+
+
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + diff --git a/docs/group__tile__traits__concept.html b/docs/group__tile__traits__concept.html index 16e4bd8ae..6c0516967 100644 --- a/docs/group__tile__traits__concept.html +++ b/docs/group__tile__traits__concept.html @@ -77,7 +77,7 @@ $(function() {
+ + + + + + + + + + + @@ -237,363 +461,541 @@ template<typename Storage_, int Rank_> -

Member Function Documentation

- -

◆ advance()

+ +

◆ TensorRef() [4/5]

-template<typename Storage_, int Rank_>
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
CUTLASS_HOST_DEVICE cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::TensorRef (Storageptr,
StrideVector const &  stride 
- - -
- + - - - - -
CUTLASS_HOST_DEVICE TensorRef& cutlass::TensorRef< Storage_, Rank_ >::advance CUTLASS_HOST_DEVICE cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::TensorRef (Coord< Rank > const & b)
-
-inline
-
- -
- - -

◆ at() [1/2]

- -
-
-
-template<typename Storage_, int Rank_>
- - - - - -
- - - - - - - - -
CUTLASS_HOST_DEVICE Storage& cutlass::TensorRef< Storage_, Rank_ >::at (Coord< Rank > const & coord) const
-
-inline
-
- -
-
- -

◆ at() [2/2]

- -
-
-
-template<typename Storage_, int Rank_>
- - - - - -
- - - - - - - - -
CUTLASS_HOST_DEVICE Storage& cutlass::TensorRef< Storage_, Rank_ >::at (int idx) const
-
-inline
-
- -
-
- -

◆ convert()

- -
-
-
-template<typename Storage_, int Rank_>
-
-template<typename T >
- - - - - -
- - - - - - - -
TensorRef<T, Rank> cutlass::TensorRef< Storage_, Rank_ >::convert ()
-
-inline
-
- -
-
- -

◆ data()

- -
-
-
-template<typename Storage_, int Rank_>
- - - - - -
- - - - - - - -
CUTLASS_HOST_DEVICE Storage* cutlass::TensorRef< Storage_, Rank_ >::data () const
-
-inline
-
- -
-
- -

◆ good()

- -
-
-
-template<typename Storage_, int Rank_>
- - - - - -
- - - - - - - -
CUTLASS_HOST_DEVICE bool cutlass::TensorRef< Storage_, Rank_ >::good () const
-
-inline
-
- -
-
- -

◆ leading_dim()

- -
-
-
-template<typename Storage_, int Rank_>
- - - - - -
- - - - - - - -
CUTLASS_HOST_DEVICE int cutlass::TensorRef< Storage_, Rank_ >::leading_dim () const
-
-inline
-
- -
-
- -

◆ offset()

- -
-
-
-template<typename Storage_, int Rank_>
- - - - - -
- - - - - - - - -
CUTLASS_HOST_DEVICE long long cutlass::TensorRef< Storage_, Rank_ >::offset (Coord< Rank > const & coord) const
-
-inline
-
- -
-
- -

◆ operator+()

- -
-
-
-template<typename Storage_, int Rank_>
- - - - - -
- - - - - - - - -
CUTLASS_HOST_DEVICE TensorRef cutlass::TensorRef< Storage_, Rank_ >::operator+ (Coord< Rank > const & b) const
-
-inline
-
- -
-
- -

◆ operator-()

- -
-
-
-template<typename Storage_, int Rank_>
- - - - - -
- - - - - - - - -
CUTLASS_HOST_DEVICE TensorRef cutlass::TensorRef< Storage_, Rank_ >::operator- (Coord< Rank > const & b) const
-
-inline
-
- -
-
- -

◆ operator[]() [1/2]

- -
-
-
-template<typename Storage_, int Rank_>
- - - - - -
- - - - - - - - -
Storage& cutlass::TensorRef< Storage_, Rank_ >::operator[] (Coord< Rank > const & coord) const
-
-inline
-
- -
-
- -

◆ operator[]() [2/2]

- -
-
-
-template<typename Storage_, int Rank_>
- - - - - -
- - - - - - - - -
Storage& cutlass::TensorRef< Storage_, Rank_ >::operator[] (int idx) const
-
-inline
-
- -
-
- -

◆ reset()

- -
-
-
-template<typename Storage_, int Rank_>
- - - + + +
- - - - - - + + - - + + + + + + + + +
CUTLASS_HOST_DEVICE void cutlass::TensorRef< Storage_, Rank_ >::reset (Storageptr = nullptr, Storageptr,
Coord< Rankstride = Coord<Rank>(0) StorageCoord const & stride 
)
+
+inline
+
+

Constructs from a pointer and a stride vector of size kRank. If fastest changing stride is not 1, construction fails and subsequent calls to good() will return false.

+ +
+
+ +

◆ TensorRef() [5/5]

+ +
+
+
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::TensorRef (TensorRef< typename platform::remove_const< Storage >::type, kRank, MapFunc, kStorageRank, Index, LongIndex > const & ref)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ add_pointer_offset()

+ +
+
+
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE TensorRef& cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::add_pointer_offset (LongIndex delta)
+
+inline
+
+ +
+
+ +

◆ at() [1/2]

+ +
+
+
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE Storage& cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::at (TensorCoord const & coord) const
+
+inline
+
+ +
+
+ +

◆ at() [2/2]

+ +
+
+
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE Storage& cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::at (LongIndex idx) const
+
+inline
+
+ +
+
+ +

◆ const_ref()

+ +
+
+
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + +
CUTLASS_HOST_DEVICE ConstTensorRef cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::const_ref () const
+
+inline
+
+ +
+
+ +

◆ data()

+ +
+
+
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + +
CUTLASS_HOST_DEVICE Storage* cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::data () const
+
+inline
+
+ +
+
+ +

◆ good()

+ +
+
+
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + +
CUTLASS_HOST_DEVICE bool cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::good () const
+
+inline
+
+ +
+
+ +

◆ leading_dim()

+ +
+
+
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE Index cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::leading_dim (int idx = 0) const
+
+inline
+
+ +
+
+ +

◆ map()

+ +
+
+
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE StorageCoord cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::map (TensorCoord const & coord) const
+
+inline
+
+ +
+
+ +

◆ offset()

+ +
+
+
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE LongIndex cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::offset (TensorCoord const & coord) const
+
+inline
+
+ +
+
+ +

◆ operator+()

+ +
+
+
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE TensorRef cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::operator+ (TensorCoord const & b) const
+
+inline
+
+ +
+
+ +

◆ operator+=()

+ +
+
+
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE TensorRef& cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::operator+= (TensorCoord const & b)
+
+inline
+
+ +
+
+ +

◆ operator-()

+ +
+
+
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE TensorRef cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::operator- (TensorCoord const & b) const
+
+inline
+
+ +
+
+ +

◆ operator-=()

+ +
+
+
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE TensorRef& cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::operator-= (TensorCoord const & b)
+
+inline
+
+ +
+
+ +

◆ operator[]() [1/2]

+ +
+
+
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE Storage& cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::operator[] (TensorCoord const & coord) const
+
+inline
+
+ +
+
+ +

◆ operator[]() [2/2]

+ +
+
+
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE Storage& cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::operator[] (LongIndex idx) const
+
+inline
+
+ +
+
+ +

◆ reset() [1/2]

+ +
+
+
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE void cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::reset (Storageptr = nullptr)
+
+inline
+
+ +
+
+ +

◆ reset() [2/2]

+ +
+
+
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + diff --git a/docs/group__tile__store__iterator__concept.html b/docs/group__tile__store__iterator__concept.html index bde540531..992a7ca39 100644 --- a/docs/group__tile__store__iterator__concept.html +++ b/docs/group__tile__store__iterator__concept.html @@ -77,7 +77,7 @@ $(function() {
+ + + + + + + + + + + + @@ -610,19 +1012,19 @@ template<typename Storage_, int Rank_> - -

◆ stride() [1/2]

+ +

◆ stride() [1/2]

-template<typename Storage_, int Rank_>
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
CUTLASS_HOST_DEVICE void cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::reset (Storageptr,
StorageCoord const & stride 
diff --git a/docs/group__fragment__iterator__concept.html b/docs/group__fragment__iterator__concept.html index dc89e72e5..e08d36e44 100644 --- a/docs/group__fragment__iterator__concept.html +++ b/docs/group__fragment__iterator__concept.html @@ -91,7 +91,7 @@ Classes diff --git a/docs/group__layout__concept.html b/docs/group__layout__concept.html index 3fe8532c8..66a828819 100644 --- a/docs/group__layout__concept.html +++ b/docs/group__layout__concept.html @@ -100,7 +100,7 @@ Classes diff --git a/docs/group__predicate__iterator__concept.html b/docs/group__predicate__iterator__concept.html index 95c1ef2ef..9c3b71084 100644 --- a/docs/group__predicate__iterator__concept.html +++ b/docs/group__predicate__iterator__concept.html @@ -98,7 +98,7 @@ Classes diff --git a/docs/group__predicate__tile__adapter.html b/docs/group__predicate__tile__adapter.html index a4b809922..8ab28fed9 100644 --- a/docs/group__predicate__tile__adapter.html +++ b/docs/group__predicate__tile__adapter.html @@ -80,7 +80,7 @@ $(function() { diff --git a/docs/group__predicate__vector__concept.html b/docs/group__predicate__vector__concept.html index 5147870e6..cf4fd5b2a 100644 --- a/docs/group__predicate__vector__concept.html +++ b/docs/group__predicate__vector__concept.html @@ -92,7 +92,7 @@ Classes diff --git a/docs/group__tile__load__iterator__concept.html b/docs/group__tile__load__iterator__concept.html index 2bc4b4e34..edc492818 100644 --- a/docs/group__tile__load__iterator__concept.html +++ b/docs/group__tile__load__iterator__concept.html @@ -77,7 +77,7 @@ $(function() {
- + @@ -637,19 +1039,19 @@ template<typename Storage_, int Rank_> - -

◆ stride() [2/2]

+ +

◆ stride() [2/2]

-template<typename Storage_, int Rank_>
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
CUTLASS_HOST_DEVICE Coord<Rank> const& cutlass::TensorRef< Storage_, Rank_ >::stride CUTLASS_HOST_DEVICE StorageCoord cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::stride ( ) const
+ + +
- + @@ -666,19 +1068,67 @@ template<typename Storage_, int Rank_>

Member Data Documentation

- -

◆ Rank

+ +

◆ kRank

-template<typename Storage_, int Rank_>
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
CUTLASS_HOST_DEVICE int const& cutlass::TensorRef< Storage_, Rank_ >::stride CUTLASS_HOST_DEVICE Index cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::stride ( int  dim)
+ + +
- + + +
int const cutlass::TensorRef< Storage_, Rank_ >::Rank = Rank_int const cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::kRank = Rank_
+
+static
+
+ +
+ + +

◆ kStorageRank

+ +
+
+
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + +
int const cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::kStorageRank = StorageRank_
+
+static
+
+ +
+
+ +

◆ Rank

+ +
+
+
+template<typename Storage_, int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + @@ -696,7 +1146,7 @@ template<typename Storage_, int Rank_> diff --git a/docs/classcutlass_1_1TensorRef.png b/docs/classcutlass_1_1TensorRef.png new file mode 100644 index 000000000..f8caaa61d Binary files /dev/null and b/docs/classcutlass_1_1TensorRef.png differ diff --git a/docs/classcutlass_1_1TensorRefArray_1_1ConstIterator-members.html b/docs/classcutlass_1_1TensorRefArray_1_1ConstIterator-members.html new file mode 100644 index 000000000..44c118956 --- /dev/null +++ b/docs/classcutlass_1_1TensorRefArray_1_1ConstIterator-members.html @@ -0,0 +1,101 @@ + + + + + + + +Cutlass: Member List + + + + + + + + + + +
+
+
+ + +
int const cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::Rank = kRank
+ + + + + +
+
Cutlass +
+
CUDA Templates for Linear Algebra Subroutines and Solvers
+
+
+ + + + + + + + +
+
+ + +
+ +
+ + +
+
+
+
cutlass::TensorRefArray< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator Member List
+
+
+ +

This is the complete list of members for cutlass::TensorRefArray< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator, including all inherited members.

+ + + + + + + + + + + + +
ConstIterator(TensorArrayRef const &ref, int idx=0)cutlass::TensorRefArray< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIteratorinline
operator() constcutlass::TensorRefArray< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIteratorinline
operator+(Index idx)cutlass::TensorRefArray< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIteratorinline
operator++()cutlass::TensorRefArray< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIteratorinline
operator++(int)cutlass::TensorRefArray< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIteratorinline
operator+=(Index idx)cutlass::TensorRefArray< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIteratorinline
operator-(Index idx)cutlass::TensorRefArray< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIteratorinline
operator--()cutlass::TensorRefArray< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIteratorinline
operator--(int)cutlass::TensorRefArray< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIteratorinline
operator-=(Index idx)cutlass::TensorRefArray< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIteratorinline
TensorRef typedefcutlass::TensorRefArray< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator
+ + + + diff --git a/docs/classcutlass_1_1TensorRefArray_1_1ConstIterator.html b/docs/classcutlass_1_1TensorRefArray_1_1ConstIterator.html new file mode 100644 index 000000000..aa40085cb --- /dev/null +++ b/docs/classcutlass_1_1TensorRefArray_1_1ConstIterator.html @@ -0,0 +1,440 @@ + + + + + + + +Cutlass: cutlass::TensorRefArray< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator Class Reference + + + + + + + + + + +
+
+ + + + + + +
+
Cutlass +
+
CUDA Templates for Linear Algebra Subroutines and Solvers
+
+
+ + + + + + + + +
+
+ + +
+ +
+ + +
+
+ +
+
cutlass::TensorRefArray< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator Class Reference
+
+
+ +

TensorRefIterator over TensorRef objects in TensorRefArray. +

+ +

#include <tensor_ref_collection.h>

+ + + + + +

+Public Types

typedef Base TensorRef
 TensorRef returned by the iterator. More...
 
+ + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

CUTLASS_HOST_DEVICE ConstIterator (TensorArrayRef const &ref, int idx=0)
 Constructs a ConstIterator over the TensorRef objects. More...
 
CUTLASS_HOST_DEVICE TensorRefoperator () const
 Obtains a TensorRef pointed to by this iterator. More...
 
CUTLASS_HOST_DEVICE ConstIteratoroperator++ ()
 Advances to next TensorRef. More...
 
CUTLASS_HOST_DEVICE ConstIterator operator++ (int)
 Advances to next TensorRef. More...
 
CUTLASS_HOST_DEVICE ConstIterator operator+ (Index idx)
 
CUTLASS_HOST_DEVICE ConstIteratoroperator+= (Index idx)
 
CUTLASS_HOST_DEVICE ConstIteratoroperator-- ()
 
CUTLASS_HOST_DEVICE ConstIterator operator-- (int)
 Advances to next TensorRef. More...
 
CUTLASS_HOST_DEVICE ConstIteratoroperator-= (Index idx)
 
CUTLASS_HOST_DEVICE ConstIterator operator- (Index idx)
 
+

Member Typedef Documentation

+ +

◆ TensorRef

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + +
typedef Base cutlass::TensorRefArray< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator::TensorRef
+
+ +
+
+

Constructor & Destructor Documentation

+ +

◆ ConstIterator()

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + + + + + + + + + + + +
CUTLASS_HOST_DEVICE cutlass::TensorRefArray< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator::ConstIterator (TensorArrayRef const & ref,
int idx = 0 
)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ operator()

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + +
CUTLASS_HOST_DEVICE TensorRef* cutlass::TensorRefArray< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator::operator () const
+
+inline
+
+ +
+
+ +

◆ operator+()

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE ConstIterator cutlass::TensorRefArray< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator::operator+ (Index idx)
+
+inline
+
+ +
+
+ +

◆ operator++() [1/2]

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + +
CUTLASS_HOST_DEVICE ConstIterator& cutlass::TensorRefArray< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator::operator++ ()
+
+inline
+
+ +
+
+ +

◆ operator++() [2/2]

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE ConstIterator cutlass::TensorRefArray< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator::operator++ (int )
+
+inline
+
+ +
+
+ +

◆ operator+=()

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE ConstIterator& cutlass::TensorRefArray< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator::operator+= (Index idx)
+
+inline
+
+ +
+
+ +

◆ operator-()

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE ConstIterator cutlass::TensorRefArray< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator::operator- (Index idx)
+
+inline
+
+ +
+
+ +

◆ operator--() [1/2]

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + +
CUTLASS_HOST_DEVICE ConstIterator& cutlass::TensorRefArray< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator::operator-- ()
+
+inline
+
+ +
+
+ +

◆ operator--() [2/2]

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE ConstIterator cutlass::TensorRefArray< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator::operator-- (int )
+
+inline
+
+ +
+
+ +

◆ operator-=()

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE ConstIterator& cutlass::TensorRefArray< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator::operator-= (Index idx)
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/classcutlass_1_1TensorRefBatchStrided_1_1ConstIterator-members.html b/docs/classcutlass_1_1TensorRefBatchStrided_1_1ConstIterator-members.html new file mode 100644 index 000000000..bb3876187 --- /dev/null +++ b/docs/classcutlass_1_1TensorRefBatchStrided_1_1ConstIterator-members.html @@ -0,0 +1,102 @@ + + + + + + + +Cutlass: Member List + + + + + + + + + + +
+
+ + + + + + +
+
Cutlass +
+
CUDA Templates for Linear Algebra Subroutines and Solvers
+
+
+ + + + + + + + +
+
+ + +
+ +
+ + +
+
+
+
cutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator Member List
+
+
+ +

This is the complete list of members for cutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator, including all inherited members.

+ + + + + + + + + + + + + +
ConstIterator(TensorRefBatchStrided const &ref, LongIndex offset=0)cutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIteratorinline
operator() constcutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIteratorinline
operator+(Index idx)cutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIteratorinline
operator++()cutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIteratorinline
operator++(int)cutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIteratorinline
operator+=(Index idx)cutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIteratorinline
operator-(Index idx)cutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIteratorinline
operator-(ConstIterator const &it)cutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIteratorinline
operator--()cutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIteratorinline
operator--(int)cutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIteratorinline
operator-=(Index idx)cutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIteratorinline
TensorRef typedefcutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator
+ + + + diff --git a/docs/classcutlass_1_1TensorRefBatchStrided_1_1ConstIterator.html b/docs/classcutlass_1_1TensorRefBatchStrided_1_1ConstIterator.html new file mode 100644 index 000000000..c3dbd9dfc --- /dev/null +++ b/docs/classcutlass_1_1TensorRefBatchStrided_1_1ConstIterator.html @@ -0,0 +1,476 @@ + + + + + + + +Cutlass: cutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator Class Reference + + + + + + + + + + +
+
+ + + + + + +
+
Cutlass +
+
CUDA Templates for Linear Algebra Subroutines and Solvers
+
+
+ + + + + + + + +
+
+ + +
+ +
+ + +
+
+ +
+
cutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator Class Reference
+
+
+ +

Constant iterator over tensors implied by TensorRefBatchStrided. +

+ +

#include <tensor_ref_collection.h>

+ + + + + +

+Public Types

typedef Base TensorRef
 TensorRef returned by the iterator. More...
 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

CUTLASS_HOST_DEVICE ConstIterator (TensorRefBatchStrided const &ref, LongIndex offset=0)
 Constructs a ConstIterator from a parent TensorRefBatchStrided. More...
 
CUTLASS_HOST_DEVICE TensorRefoperator () const
 Obtains a TensorRef pointed to by the iterator. More...
 
CUTLASS_HOST_DEVICE ConstIteratoroperator++ ()
 Advances the iterator to point to the next tensor. More...
 
CUTLASS_HOST_DEVICE ConstIterator operator++ (int)
 Advances the iterator to point to the next tensor. More...
 
CUTLASS_HOST_DEVICE ConstIterator operator+ (Index idx)
 Returns an iterator advanced by (idx) amount. More...
 
CUTLASS_HOST_DEVICE ConstIteratoroperator+= (Index idx)
 Advances this iterator by (idx) and returns a reference to self. More...
 
CUTLASS_HOST_DEVICE ConstIteratoroperator-- ()
 Moves to the previous tensor. More...
 
CUTLASS_HOST_DEVICE ConstIterator operator-- (int)
 Moves to the previous tensor. More...
 
CUTLASS_HOST_DEVICE ConstIterator operator- (Index idx)
 Returns an iterator moved forward by (idx) amount. More...
 
CUTLASS_HOST_DEVICE ConstIteratoroperator-= (Index idx)
 Moves this iterator by (idx) and returns a reference to self. More...
 
CUTLASS_HOST_DEVICE Stride operator- (ConstIterator const &it)
 Returns the difference in offset between two iterators. More...
 
+

Member Typedef Documentation

+ +

◆ TensorRef

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + +
typedef Base cutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator::TensorRef
+
+ +
+
+

Constructor & Destructor Documentation

+ +

◆ ConstIterator()

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + + + + + + + + + + + +
CUTLASS_HOST_DEVICE cutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator::ConstIterator (TensorRefBatchStrided const & ref,
LongIndex offset = 0 
)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ operator()

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + +
CUTLASS_HOST_DEVICE TensorRef* cutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator::operator () const
+
+inline
+
+ +
+
+ +

◆ operator+()

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE ConstIterator cutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator::operator+ (Index idx)
+
+inline
+
+ +
+
+ +

◆ operator++() [1/2]

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + +
CUTLASS_HOST_DEVICE ConstIterator& cutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator::operator++ ()
+
+inline
+
+ +
+
+ +

◆ operator++() [2/2]

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE ConstIterator cutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator::operator++ (int )
+
+inline
+
+ +
+
+ +

◆ operator+=()

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE ConstIterator& cutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator::operator+= (Index idx)
+
+inline
+
+ +
+
+ +

◆ operator-() [1/2]

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE ConstIterator cutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator::operator- (Index idx)
+
+inline
+
+ +
+
+ +

◆ operator-() [2/2]

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE Stride cutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator::operator- (ConstIterator const & it)
+
+inline
+
+ +
+
+ +

◆ operator--() [1/2]

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + +
CUTLASS_HOST_DEVICE ConstIterator& cutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator::operator-- ()
+
+inline
+
+ +
+
+ +

◆ operator--() [2/2]

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE ConstIterator cutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator::operator-- (int )
+
+inline
+
+ +
+
+ +

◆ operator-=()

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE ConstIterator& cutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIterator::operator-= (Index idx)
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/classcutlass_1_1TensorRef_3_01Storage___00_01Rank___00_01MapFunc___00_011_00_01Index___00_01LongIndex___01_4-members.html b/docs/classcutlass_1_1TensorRef_3_01Storage___00_01Rank___00_01MapFunc___00_011_00_01Index___00_01LongIndex___01_4-members.html new file mode 100644 index 000000000..8af74ab9b --- /dev/null +++ b/docs/classcutlass_1_1TensorRef_3_01Storage___00_01Rank___00_01MapFunc___00_011_00_01Index___00_01LongIndex___01_4-members.html @@ -0,0 +1,124 @@ + + + + + + + +Cutlass: Member List + + + + + + + + + + +
+
+ + + + + + +
+
Cutlass +
+
CUDA Templates for Linear Algebra Subroutines and Solvers
+
+
+ + + + + + + + +
+
+ + +
+ +
+ + +
+
+
+
cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ > Member List
+
+
+ +

This is the complete list of members for cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
add_pointer_offset(LongIndex delta)cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >inline
at(TensorCoord const &coord) constcutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >inline
at(LongIndex idx) constcutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >inline
const_ref() constcutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >inline
ConstTensorRef typedefcutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >
Coord_t typedefcutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >
data() constcutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >inline
good() constcutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >inline
Index typedefcutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >
kRankcutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >static
kStorageRankcutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >static
leading_dim(int idx=0) constcutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >inline
LongIndex typedefcutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >
map(TensorCoord const &coord) constcutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >inline
MapFunc typedefcutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >
offset(TensorCoord const &coord) constcutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >inline
operator+(TensorCoord const &b) constcutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >inline
operator+=(TensorCoord const &b)cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >inline
operator-(TensorCoord const &b) constcutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >inline
operator-=(TensorCoord const &b)cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >inline
operator[](TensorCoord const &coord) constcutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >inline
operator[](LongIndex idx) constcutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >inline
Rankcutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >static
reset(Storage *ptr=nullptr)cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >inline
reset(Storage *ptr, StorageCoord const &stride)cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >inline
Storage typedefcutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >
StorageCoord typedefcutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >
stride() constcutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >inline
stride(int dim) constcutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >inline
TensorCoord typedefcutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >
TensorRef(Storage *ptr=nullptr)cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >inline
TensorRef(Storage *ptr, StrideVector const &stride)cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >inline
TensorRef(Storage *ptr, StorageCoord const &stride)cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >inline
TensorRef(TensorRef< typename platform::remove_const< Storage >::type, kRank, MapFunc, kStorageRank, Index, LongIndex > const &ref)cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >inline
+ + + + diff --git a/docs/classcutlass_1_1TensorRef_3_01Storage___00_01Rank___00_01MapFunc___00_011_00_01Index___00_01LongIndex___01_4.html b/docs/classcutlass_1_1TensorRef_3_01Storage___00_01Rank___00_01MapFunc___00_011_00_01Index___00_01LongIndex___01_4.html new file mode 100644 index 000000000..2dfd10c99 --- /dev/null +++ b/docs/classcutlass_1_1TensorRef_3_01Storage___00_01Rank___00_01MapFunc___00_011_00_01Index___00_01LongIndex___01_4.html @@ -0,0 +1,1092 @@ + + + + + + + +Cutlass: cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ > Class Template Reference + + + + + + + + + + +
+
+ + + + + + +
+
Cutlass +
+
CUDA Templates for Linear Algebra Subroutines and Solvers
+
+
+ + + + + + + + +
+
+ + +
+ +
+ + +
+
+ +
+
cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ > Class Template Reference
+
+
+ +

Specialization for rank=1 case with no internal StrideVector. +

+ +

#include <tensor_ref.h>

+ + + + +

+Classes

struct  StrideVector
 
+ + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Types

typedef Storage_ Storage
 Data type of individual access. More...
 
typedef MapFunc_ MapFunc
 Mapping function from logical coordinate to internal n-D array. More...
 
typedef Index_ Index
 Index type. More...
 
typedef LongIndex_ LongIndex
 Typically, strides in memory can be very large. More...
 
typedef Coord< kRankTensorCoord
 Coordinate in logical tensor space. More...
 
typedef Coord< kStorageRankStorageCoord
 Coordinate in storage n-D array. More...
 
typedef TensorRef< typename platform::remove_const< Storage >::type const, Rank_, MapFunc_, kStorageRank, Index_, LongIndex_ > ConstTensorRef
 Tensor reference to of constant value. More...
 
typedef TensorCoord Coord_t
 Coordinate in logical tensor space. More...
 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

CUTLASS_HOST_DEVICE TensorRef (Storage *ptr=nullptr)
 Helper for 1-D memory. All higher ranks are projected onto the fastest changing rank. More...
 
CUTLASS_HOST_DEVICE TensorRef (Storage *ptr, StrideVector const &stride)
 Constructs from a single pointer and stride vector. More...
 
CUTLASS_HOST_DEVICE TensorRef (Storage *ptr, StorageCoord const &stride)
 
CUTLASS_HOST_DEVICE TensorRef (TensorRef< typename platform::remove_const< Storage >::type, kRank, MapFunc, kStorageRank, Index, LongIndex > const &ref)
 Enables conversion from TensorRef of non-const type. More...
 
CUTLASS_HOST_DEVICE ConstTensorRef const_ref () const
 Returns a reference to constant-valued tensor. More...
 
CUTLASS_HOST_DEVICE void reset (Storage *ptr=nullptr)
 Updates only the pointer. More...
 
CUTLASS_HOST_DEVICE void reset (Storage *ptr, StorageCoord const &stride)
 Updates the pointer, stride, and location within a TensorRef. More...
 
CUTLASS_HOST_DEVICE bool good () const
 Returns true if the TensorRef may be safely accessed. More...
 
CUTLASS_HOST_DEVICE Storagedata () const
 Returns the pointer to referenced data. More...
 
CUTLASS_HOST_DEVICE StorageCoord stride () const
 Returns the stride of the tensor. More...
 
CUTLASS_HOST_DEVICE Index stride (int dim) const
 Returns the stride of the tensor in the given dimension. More...
 
CUTLASS_HOST_DEVICE Index leading_dim (int idx=0) const
 Returns the maximum stride element as the 'leading dimension'. More...
 
CUTLASS_HOST_DEVICE StorageCoord map (TensorCoord const &coord) const
 Maps a logical coordinate to an n-D array in memory. More...
 
CUTLASS_HOST_DEVICE LongIndex offset (TensorCoord const &coord) const
 Computes the offset of an index from the origin of the tensor. More...
 
CUTLASS_HOST_DEVICE Storageat (TensorCoord const &coord) const
 Returns a reference to the element at a given Coord. More...
 
CUTLASS_HOST_DEVICE Storageat (LongIndex idx) const
 Returns a reference to the element at a given linear index. More...
 
CUTLASS_HOST_DEVICE Storageoperator[] (TensorCoord const &coord) const
 Returns a reference to the element at a given Coord. More...
 
CUTLASS_HOST_DEVICE Storageoperator[] (LongIndex idx) const
 Returns a reference to the element at a given linear index. More...
 
CUTLASS_HOST_DEVICE TensorRefadd_pointer_offset (LongIndex delta)
 Adds an offset to each pointer. More...
 
CUTLASS_HOST_DEVICE TensorRef operator+ (TensorCoord const &b) const
 Returns a TensorRef offset by a given amount. More...
 
CUTLASS_HOST_DEVICE TensorRefoperator+= (TensorCoord const &b)
 Returns a TensorRef offset by a given amount. More...
 
CUTLASS_HOST_DEVICE TensorRef operator- (TensorCoord const &b) const
 Returns a TensorRef offset by a given amount. More...
 
CUTLASS_HOST_DEVICE TensorRefoperator-= (TensorCoord const &b)
 Returns a TensorRef offset by a given amount. More...
 
+ + + + + + + + + + +

+Static Public Attributes

static int const kRank = Rank_
 Logical rank of tensor index space. More...
 
static int const kStorageRank = 1
 Rank of internal storage. More...
 
static int const Rank = kRank
 Logical rank of tensor index space. More...
 
+

Member Typedef Documentation

+ +

◆ ConstTensorRef

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + +
typedef TensorRef< typename platform::remove_const<Storage>::type const, Rank_, MapFunc_, kStorageRank, Index_, LongIndex_> cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::ConstTensorRef
+
+ +
+
+ +

◆ Coord_t

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + +
typedef TensorCoord cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::Coord_t
+
+ +
+
+ +

◆ Index

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + +
typedef Index_ cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::Index
+
+ +
+
+ +

◆ LongIndex

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + +
typedef LongIndex_ cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::LongIndex
+
+ +
+
+ +

◆ MapFunc

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + +
typedef MapFunc_ cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::MapFunc
+
+ +
+
+ +

◆ Storage

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + +
typedef Storage_ cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::Storage
+
+ +
+
+ +

◆ StorageCoord

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + +
typedef Coord<kStorageRank> cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::StorageCoord
+
+ +
+
+ +

◆ TensorCoord

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + +
typedef Coord<kRank> cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::TensorCoord
+
+ +
+
+

Constructor & Destructor Documentation

+ +

◆ TensorRef() [1/4]

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::TensorRef (Storageptr = nullptr)
+
+inline
+
+ +
+
+ +

◆ TensorRef() [2/4]

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + + +
+ + + + + + + + + + + + + + + + + + +
CUTLASS_HOST_DEVICE cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::TensorRef (Storageptr,
StrideVector const & stride 
)
+
+inline
+
+ +
+
+ +

◆ TensorRef() [3/4]

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + + +
+ + + + + + + + + + + + + + + + + + +
CUTLASS_HOST_DEVICE cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::TensorRef (Storageptr,
StorageCoord const & stride 
)
+
+inline
+
+

Constructs from a pointer and a stride vector of size kRank. If fastest changing stride is not 1, construction fails and subsequent calls to good() will return false.

+ +
+
+ +

◆ TensorRef() [4/4]

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::TensorRef (TensorRef< typename platform::remove_const< Storage >::type, kRank, MapFunc, kStorageRank, Index, LongIndex > const & ref)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ add_pointer_offset()

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE TensorRef& cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::add_pointer_offset (LongIndex delta)
+
+inline
+
+ +
+
+ +

◆ at() [1/2]

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE Storage& cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::at (TensorCoord const & coord) const
+
+inline
+
+ +
+
+ +

◆ at() [2/2]

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE Storage& cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::at (LongIndex idx) const
+
+inline
+
+ +
+
+ +

◆ const_ref()

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + + +
+ + + + + + + +
CUTLASS_HOST_DEVICE ConstTensorRef cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::const_ref () const
+
+inline
+
+ +
+
+ +

◆ data()

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + + +
+ + + + + + + +
CUTLASS_HOST_DEVICE Storage* cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::data () const
+
+inline
+
+ +
+
+ +

◆ good()

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + + +
+ + + + + + + +
CUTLASS_HOST_DEVICE bool cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::good () const
+
+inline
+
+ +
+
+ +

◆ leading_dim()

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE Index cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::leading_dim (int idx = 0) const
+
+inline
+
+ +
+
+ +

◆ map()

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE StorageCoord cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::map (TensorCoord const & coord) const
+
+inline
+
+ +
+
+ +

◆ offset()

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE LongIndex cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::offset (TensorCoord const & coord) const
+
+inline
+
+ +
+
+ +

◆ operator+()

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE TensorRef cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::operator+ (TensorCoord const & b) const
+
+inline
+
+ +
+
+ +

◆ operator+=()

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE TensorRef& cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::operator+= (TensorCoord const & b)
+
+inline
+
+ +
+
+ +

◆ operator-()

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE TensorRef cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::operator- (TensorCoord const & b) const
+
+inline
+
+ +
+
+ +

◆ operator-=()

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE TensorRef& cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::operator-= (TensorCoord const & b)
+
+inline
+
+ +
+
+ +

◆ operator[]() [1/2]

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE Storage& cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::operator[] (TensorCoord const & coord) const
+
+inline
+
+ +
+
+ +

◆ operator[]() [2/2]

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE Storage& cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::operator[] (LongIndex idx) const
+
+inline
+
+ +
+
+ +

◆ reset() [1/2]

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE void cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::reset (Storageptr = nullptr)
+
+inline
+
+ +
+
+ +

◆ reset() [2/2]

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + + +
+ + + + + + + + + + + + + + + + + + +
CUTLASS_HOST_DEVICE void cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::reset (Storageptr,
StorageCoord const & stride 
)
+
+inline
+
+ +
+
+ +

◆ stride() [1/2]

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + + +
+ + + + + + + +
CUTLASS_HOST_DEVICE StorageCoord cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::stride () const
+
+inline
+
+ +
+
+ +

◆ stride() [2/2]

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE Index cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::stride (int dim) const
+
+inline
+
+ +
+
+

Member Data Documentation

+ +

◆ kRank

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + + +
+ + + + +
int const cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::kRank = Rank_
+
+static
+
+ +
+
+ +

◆ kStorageRank

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + + +
+ + + + +
int const cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::kStorageRank = 1
+
+static
+
+ +
+
+ +

◆ Rank

+ +
+
+
+template<typename Storage_ , int Rank_, typename MapFunc_ , typename Index_ , typename LongIndex_ >
+ + + + + +
+ + + + +
int const cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::Rank = kRank
+
+static
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/classcutlass_1_1TensorView-members.html b/docs/classcutlass_1_1TensorView-members.html index e9401f9cc..9f5c32535 100644 --- a/docs/classcutlass_1_1TensorView-members.html +++ b/docs/classcutlass_1_1TensorView-members.html @@ -73,51 +73,70 @@ $(function() {
-
cutlass::TensorView< T > Member List
+
cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ > Member List
-

This is the complete list of members for cutlass::TensorView< T >, including all inherited members.

+

This is the complete list of members for cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >, including all inherited members.

- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
advance(Coord< Rank > const &b)cutlass::TensorRef< T, 4 >inline
at(Coord_t const &coord) constcutlass::TensorView< T >inline
at(Offset_t idx) constcutlass::TensorView< T >inline
Base typedefcutlass::TensorView< T >
const_ref()cutlass::TensorView< T >inline
ConstTensorRef_t typedefcutlass::TensorView< T >
contains(Coord_t const &coord) constcutlass::TensorView< T >inline
convert()cutlass::TensorRef< T, 4 >inline
Coord_t typedefcutlass::TensorView< T >
data() constcutlass::TensorView< T >inline
good() constcutlass::TensorView< T >inline
leading_dim() constcutlass::TensorRef< T, 4 >inline
offset(Coord_t const &coord) constcutlass::TensorView< T >inline
Offset_t typedefcutlass::TensorView< T >
operator+(Coord< Rank > const &b) constcutlass::TensorRef< T, 4 >inline
operator-(Coord< Rank > const &b) constcutlass::TensorRef< T, 4 >inline
operator=(TensorView const &_tensor)cutlass::TensorView< T >inline
operator[](Coord< Rank > const &coord) constcutlass::TensorView< T >inline
TensorRef< T, 4 >::operator[](int idx) constcutlass::TensorRef< T, 4 >inline
Rankcutlass::TensorView< T >static
ref()cutlass::TensorView< T >inline
ref() constcutlass::TensorView< T >inline
reset(TensorRef_t const &_ref=TensorRef_t(0), Coord_t const &_size=Coord_t())cutlass::TensorView< T >inline
TensorRef< T, 4 >::reset(Storage *ptr=nullptr, Coord< Rank > stride=Coord< Rank >(0))cutlass::TensorRef< T, 4 >inline
size() constcutlass::TensorView< T >inline
size(int dim) constcutlass::TensorView< T >inline
Storage typedefcutlass::TensorRef< T, 4 >
stride() constcutlass::TensorView< T >inline
stride(int dim) constcutlass::TensorView< T >inline
subview(Coord_t const &location, Coord_t size) constcutlass::TensorView< T >inline
TensorRef()cutlass::TensorRef< T, 4 >inline
TensorRef(Storage *ptr, Coord< Rank > stride)cutlass::TensorRef< T, 4 >inline
TensorRef_t typedefcutlass::TensorView< T >
TensorView()cutlass::TensorView< T >inline
TensorView(TensorRef_t const &_ref, Coord_t const &_size)cutlass::TensorView< T >inline
add_pointer_offset(LongIndex delta)cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
at(TensorCoord const &coord) constcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
at(LongIndex idx) constcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
Base typedefcutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >
capacity() constcutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
const_ref() constcutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
ConstTensorRef typedefcutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >
ConstTensorRef_t typedefcutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >
ConstTensorView typedefcutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >
contains(TensorCoord const &coord) constcutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
Coord_t typedefcutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >
data() constcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
good() constcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
Index typedefcutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >
kRankcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >static
kStorageRankcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >static
leading_dim(int idx=0) constcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
LongIndex typedefcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >
map(TensorCoord const &coord) constcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
MapFunc typedefcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >
offset(TensorCoord const &coord) constcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
Offset_t typedefcutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >
operator+(TensorCoord const &b) constcutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
operator+=(TensorCoord const &b)cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
operator-(TensorCoord const &b) constcutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
operator-=(TensorCoord const &b)cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
operator=(TensorView const &_tensor)cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
operator[](TensorCoord const &coord) constcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
operator[](LongIndex idx) constcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
Rankcutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >static
ref() constcutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
reset(Base const &_ref=Base(), TensorCoord const &_size=TensorCoord())cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
cutlass::TensorRef::reset(Storage *ptr=nullptr)cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
cutlass::TensorRef::reset(Storage *ptr, StorageCoord const &stride)cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
size() constcutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
size(int dim) constcutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
Storage typedefcutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >
StorageCoord typedefcutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >
stride() constcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
stride(int dim) constcutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
StrideVector typedefcutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >
subview(TensorCoord const &location, TensorCoord size) constcutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
TensorCoord typedefcutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >
TensorRef typedefcutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >
cutlass::TensorRef::TensorRef(Storage *ptr=nullptr)cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
cutlass::TensorRef::TensorRef(Storage *ptr, Index ldm)cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
cutlass::TensorRef::TensorRef(Storage *ptr, StrideVector const &stride)cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
cutlass::TensorRef::TensorRef(Storage *ptr, StorageCoord const &stride)cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
cutlass::TensorRef::TensorRef(TensorRef< typename platform::remove_const< Storage >::type, kRank, MapFunc, kStorageRank, Index, LongIndex > const &ref)cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
TensorRef_t typedefcutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >
TensorView()cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
TensorView(Base const &_ref, TensorCoord const &_size)cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
TensorView(Storage *ptr, StrideVector const &stride, TensorCoord const &size)cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
TensorView(Storage *ptr, StorageCoord const &stride, TensorCoord const &size)cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >inline
diff --git a/docs/classcutlass_1_1TensorView.html b/docs/classcutlass_1_1TensorView.html index 7dba23228..276d1077d 100644 --- a/docs/classcutlass_1_1TensorView.html +++ b/docs/classcutlass_1_1TensorView.html @@ -5,7 +5,7 @@ -Cutlass: cutlass::TensorView< T > Class Template Reference +Cutlass: cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ > Class Template Reference @@ -78,242 +78,438 @@ $(function() { Static Public Attributes | List of all members
-
cutlass::TensorView< T > Class Template Reference
+
cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ > Class Template Reference
-

Host-side reference implementation of tensor operations. +

Defines a view into a logical tensor.

#include <tensor_view.h>

-Inheritance diagram for cutlass::TensorView< T >:
+Inheritance diagram for cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >:
- - -cutlass::TensorRef< T, 4 > + + +cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >
- - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

Public Types

typedef TensorRef< T, 4 > Base
 Reference and stride. More...
 
typedef Base TensorRef_t
 Reference and stride. More...
 
typedef TensorRef< T const, 4 > ConstTensorRef_t
 Reference to constant type. More...
 
typedef int Offset_t
 Type used to compute the offset of an element to the base of a tensor. More...
 
typedef Coord< RankCoord_t
 Coordinate into tensor. More...
 
- Public Types inherited from cutlass::TensorRef< T, 4 >
typedef T Storage
 Data type of individual access. More...
 
typedef TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ > Base
 Base tensor reference. More...
 
typedef TensorRef< typename platform::remove_const< Storage_ >::type const, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ > ConstTensorRef
 Tensor reference to of constant value. More...
 
typedef Base TensorRef
 Base tensor reference. More...
 
typedef Base::Storage Storage
 Storage type. More...
 
typedef Base::Index Index
 Index type. More...
 
typedef TensorRef::TensorCoord TensorCoord
 Coordinate in logical tensor space. More...
 
typedef TensorRef::StorageCoord StorageCoord
 Coordinate in storage n-D array. More...
 
typedef TensorRef::StrideVector StrideVector
 
typedef TensorView< typename platform::remove_const< Storage >::type const, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ > ConstTensorView
 TensorView of constant value. More...
 
typedef TensorCoord Coord_t
 Coordinate in logical tensor space. More...
 
typedef Base::LongIndex Offset_t
 Type used to compute the offset of an element to the base of a tensor. More...
 
typedef TensorRef TensorRef_t
 Base class. More...
 
typedef TensorRef::ConstTensorRef ConstTensorRef_t
 TensorRef to const-valued type. More...
 
- Public Types inherited from cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >
typedef Storage_ Storage
 Data type of individual access. More...
 
typedef MapFunc_ MapFunc
 Mapping function from logical coordinate to internal n-D array. More...
 
typedef Index_ Index
 Index type. More...
 
typedef LongIndex_ LongIndex
 Typically, strides in memory can be very large. More...
 
typedef Coord< kRankTensorCoord
 Coordinate in logical tensor space. More...
 
typedef Coord< kStorageRankStorageCoord
 Coordinate in storage n-D array. More...
 
typedef Coord< kStorageRank - 1 > StrideVector
 
typedef TensorRef< typename platform::remove_const< Storage >::type const, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ > ConstTensorRef
 Tensor reference to of constant value. More...
 
typedef TensorCoord Coord_t
 Coordinate in logical tensor space. More...
 
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

Public Member Functions

CUTLASS_HOST_DEVICE TensorView ()
 Default constructor. More...
 
CUTLASS_HOST_DEVICE TensorView (TensorRef_t const &_ref, Coord_t const &_size)
 Constructs a Tensor_view from a TensorRef and size. More...
 
CUTLASS_HOST_DEVICE bool good () const
 Returns true if the Tensor_view is bound to some memory. More...
 
CUTLASS_HOST_DEVICE T * data () const
 Returns a pointer to data. More...
 
CUTLASS_HOST_DEVICE void reset (TensorRef_t const &_ref=TensorRef_t(0), Coord_t const &_size=Coord_t())
 Updates the reference and size of a Tensor_view object. More...
 
CUTLASS_HOST_DEVICE TensorRef_tref ()
 Accesses the tensor reference pointing to data. More...
 
CUTLASS_HOST_DEVICE ConstTensorRef_t const_ref ()
 
CUTLASS_HOST_DEVICE TensorRef_t const & ref () const
 Accesses the tensor reference pointing to data. More...
 
CUTLASS_HOST_DEVICE Coord_t const & size () const
 Accesses the size. More...
 
CUTLASS_HOST_DEVICE int size (int dim) const
 Accesses the size. More...
 
CUTLASS_HOST_DEVICE Coord_t const & stride () const
 Accesses the stride. More...
 
CUTLASS_HOST_DEVICE int const & stride (int dim) const
 Accesses the stride. More...
 
CUTLASS_HOST_DEVICE TensorViewoperator= (TensorView const &_tensor)
 Assigns the Tensor_view. More...
 
CUTLASS_HOST_DEVICE Offset_t offset (Coord_t const &coord) const
 Returns the index of an element. More...
 
CUTLASS_HOST_DEVICE bool contains (Coord_t const &coord) const
 Determines whether a location is within a tensor. More...
 
CUTLASS_HOST_DEVICE T & at (Coord_t const &coord) const
 Element-wise accessor. More...
 
T & operator[] (Coord< Rank > const &coord) const
 Element-wise accessor. More...
 
CUTLASS_HOST_DEVICE T & at (Offset_t idx) const
 Element-wise accessor. More...
 
CUTLASS_HOST_DEVICE TensorView< T > subview (Coord_t const &location, Coord_t size) const
 Returns a Tensor_view given location and size quantities. More...
 
- Public Member Functions inherited from cutlass::TensorRef< T, 4 >
CUTLASS_HOST_DEVICE TensorRef ()
 Default ctor. More...
 
CUTLASS_HOST_DEVICE TensorRef (Storage *ptr, Coord< Rank > stride)
 Constructs from a pointer, size, and stride. More...
 
CUTLASS_HOST_DEVICE void reset (Storage *ptr=nullptr, Coord< Rank > stride=Coord< Rank >(0))
 Updates the pointer, stride, and location within a TensorRef. More...
 
TensorRef< T, Rankconvert ()
 Conversion function. More...
 
CUTLASS_HOST_DEVICE bool good () const
 Returns true if the TensorRef may be safely accessed. More...
 
CUTLASS_HOST_DEVICE Storagedata () const
 Returns the pointer to referenced data. More...
 
CUTLASS_HOST_DEVICE Coord< Rank > const & stride () const
 Returns the stride of the tensor. More...
 
CUTLASS_HOST_DEVICE int const & stride (int dim) const
 Returns the stride of the tensor in the given dimension. More...
 
CUTLASS_HOST_DEVICE int leading_dim () const
 Returns the maximum stride element as the 'leading dimension'. More...
 
CUTLASS_HOST_DEVICE long long offset (Coord< Rank > const &coord) const
 Computes the offset of an index from the origin of the tensor. More...
 
CUTLASS_HOST_DEVICE Storageat (Coord< Rank > const &coord) const
 Returns a reference to the element at a given Coord. More...
 
CUTLASS_HOST_DEVICE Storageat (int idx) const
 Returns a reference to the element at a given Coord. More...
 
Storageoperator[] (Coord< Rank > const &coord) const
 Element-wise accessor. More...
 
Storageoperator[] (int idx) const
 Element-wise accessor. More...
 
CUTLASS_HOST_DEVICE TensorRefadvance (Coord< Rank > const &b)
 Adds an offset to the pointer. More...
 
CUTLASS_HOST_DEVICE TensorRef operator+ (Coord< Rank > const &b) const
 Returns a TensorRef offset by a given amount. More...
 
CUTLASS_HOST_DEVICE TensorRef operator- (Coord< Rank > const &b) const
 Returns a TensorRef offset by a given amount. More...
 
CUTLASS_HOST_DEVICE TensorView ()
 Default constructor. More...
 
CUTLASS_HOST_DEVICE TensorView (Base const &_ref, TensorCoord const &_size)
 Constructs a TensorView from a TensorRef and size. More...
 
CUTLASS_HOST_DEVICE TensorView (Storage *ptr, StrideVector const &stride, TensorCoord const &size)
 Constructs a TensorView from a pointer, a stride vector, and size. More...
 
CUTLASS_HOST_DEVICE TensorView (Storage *ptr, StorageCoord const &stride, TensorCoord const &size)
 Constructs a TensorView from a pointer, a stride vector, and size. More...
 
CUTLASS_HOST_DEVICE void reset (Base const &_ref=Base(), TensorCoord const &_size=TensorCoord())
 Updates the reference and size of a Tensor_view object. More...
 
CUTLASS_HOST_DEVICE TensorCoord const & size () const
 Accesses the size. More...
 
CUTLASS_HOST_DEVICE Index size (int dim) const
 Accesses the size. More...
 
CUTLASS_HOST_DEVICE TensorViewoperator= (TensorView const &_tensor)
 Assigns the Tensor_view. More...
 
CUTLASS_HOST_DEVICE bool contains (TensorCoord const &coord) const
 Determines whether a location is within a tensor. More...
 
CUTLASS_HOST_DEVICE TensorRef ref () const
 Returns a TensorRef pointing to the first element of the tensor. More...
 
CUTLASS_HOST_DEVICE ConstTensorRef const_ref () const
 Returns a TensorRef pointing to the first element of the tensor. More...
 
CUTLASS_HOST_DEVICE TensorView subview (TensorCoord const &location, TensorCoord size) const
 Returns a Tensor_view given location and size quantities. More...
 
CUTLASS_HOST_DEVICE size_t capacity () const
 Returns the number of scalar elements needed to store tensor. More...
 
CUTLASS_HOST_DEVICE TensorView operator+ (TensorCoord const &b) const
 Returns a TensorView offset by a given amount. More...
 
CUTLASS_HOST_DEVICE TensorViewoperator+= (TensorCoord const &b)
 Returns a TensorRef offset by a given amount. More...
 
CUTLASS_HOST_DEVICE TensorView operator- (TensorCoord const &b) const
 Returns a TensorRef offset by a given amount. More...
 
CUTLASS_HOST_DEVICE TensorViewoperator-= (TensorCoord const &b)
 Returns a TensorRef offset by a given amount. More...
 
- Public Member Functions inherited from cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >
CUTLASS_HOST_DEVICE TensorRef (Storage *ptr=nullptr)
 Helper for 1-D memory. All higher ranks are projected onto the fastest changing rank. More...
 
CUTLASS_HOST_DEVICE TensorRef (Storage *ptr, Index ldm)
 Helper to construct from a pointer and single stride element for 2-D pitch linear memory. More...
 
CUTLASS_HOST_DEVICE TensorRef (Storage *ptr, StrideVector const &stride)
 Constructs from a single pointer and stride vector. More...
 
CUTLASS_HOST_DEVICE TensorRef (Storage *ptr, StorageCoord const &stride)
 
CUTLASS_HOST_DEVICE TensorRef (TensorRef< typename platform::remove_const< Storage >::type, kRank, MapFunc, kStorageRank, Index, LongIndex > const &ref)
 Enables conversion from TensorRef of non-const type. More...
 
CUTLASS_HOST_DEVICE ConstTensorRef const_ref () const
 Returns a reference to constant-valued tensor. More...
 
CUTLASS_HOST_DEVICE void reset (Storage *ptr=nullptr)
 Updates only the pointer. More...
 
CUTLASS_HOST_DEVICE void reset (Storage *ptr, StorageCoord const &stride)
 Updates the pointer, stride, and location within a TensorRef. More...
 
CUTLASS_HOST_DEVICE bool good () const
 Returns true if the TensorRef may be safely accessed. More...
 
CUTLASS_HOST_DEVICE Storagedata () const
 Returns the pointer to referenced data. More...
 
CUTLASS_HOST_DEVICE StorageCoord stride () const
 Returns the stride of the tensor. More...
 
CUTLASS_HOST_DEVICE Index stride (int dim) const
 Returns the stride of the tensor in the given dimension. More...
 
CUTLASS_HOST_DEVICE Index leading_dim (int idx=0) const
 Returns the maximum stride element as the 'leading dimension'. More...
 
CUTLASS_HOST_DEVICE StorageCoord map (TensorCoord const &coord) const
 Maps a logical coordinate to an n-D array in memory. More...
 
CUTLASS_HOST_DEVICE LongIndex offset (TensorCoord const &coord) const
 Computes the offset of an index from the origin of the tensor. More...
 
CUTLASS_HOST_DEVICE Storageat (TensorCoord const &coord) const
 Returns a reference to the element at a given Coord. More...
 
CUTLASS_HOST_DEVICE Storageat (LongIndex idx) const
 Returns a reference to the element at a given linear index. More...
 
CUTLASS_HOST_DEVICE Storageoperator[] (TensorCoord const &coord) const
 Returns a reference to the element at a given Coord. More...
 
CUTLASS_HOST_DEVICE Storageoperator[] (LongIndex idx) const
 Returns a reference to the element at a given linear index. More...
 
CUTLASS_HOST_DEVICE TensorRefadd_pointer_offset (LongIndex delta)
 Adds an offset to each pointer. More...
 
CUTLASS_HOST_DEVICE TensorRef operator+ (TensorCoord const &b) const
 Returns a TensorRef offset by a given amount. More...
 
CUTLASS_HOST_DEVICE TensorRefoperator+= (TensorCoord const &b)
 Returns a TensorRef offset by a given amount. More...
 
CUTLASS_HOST_DEVICE TensorRef operator- (TensorCoord const &b) const
 Returns a TensorRef offset by a given amount. More...
 
CUTLASS_HOST_DEVICE TensorRefoperator-= (TensorCoord const &b)
 Returns a TensorRef offset by a given amount. More...
 
- - - - - - - + + + + + + + + + + + + +

Static Public Attributes

static int const Rank = TensorRef_t::Rank
 Rank of tensor. More...
 
- Static Public Attributes inherited from cutlass::TensorRef< T, 4 >
static int const Rank
 Rank of tensor. More...
 
static int const Rank = Base::kRank
 Logical rank of tensor index space. More...
 
- Static Public Attributes inherited from cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >
static int const kRank = Rank_
 Logical rank of tensor index space. More...
 
static int const kStorageRank = StorageRank_
 Rank of internal storage. More...
 
static int const Rank = kRank
 Logical rank of tensor index space. More...
 

Member Typedef Documentation

- -

◆ Base

+ +

◆ Base

-template<typename T>
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
- +
typedef TensorRef<T, 4> cutlass::TensorView< T >::Basetypedef TensorRef<Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_> cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::Base
- -

◆ ConstTensorRef_t

+ +

◆ ConstTensorRef

-template<typename T>
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
- +
typedef TensorRef<T const, 4> cutlass::TensorView< T >::ConstTensorRef_ttypedef TensorRef< typename platform::remove_const<Storage_>::type const, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_> cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstTensorRef
- -

◆ Coord_t

+ +

◆ ConstTensorRef_t

-template<typename T>
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
- +
typedef Coord<Rank> cutlass::TensorView< T >::Coord_ttypedef TensorRef::ConstTensorRef cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstTensorRef_t
- -

◆ Offset_t

+ +

◆ ConstTensorView

-template<typename T>
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
- +
typedef int cutlass::TensorView< T >::Offset_ttypedef TensorView< typename platform::remove_const<Storage>::type const, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_> cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstTensorView
- -

◆ TensorRef_t

+ +

◆ Coord_t

-template<typename T>
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
- + + +
typedef Base cutlass::TensorView< T >::TensorRef_ttypedef TensorCoord cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::Coord_t
+
+ +
+ + +

◆ Index

+ +
+
+
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + +
typedef Base::Index cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::Index
+
+ +
+
+ +

◆ Offset_t

+ +
+
+
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + +
typedef Base::LongIndex cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::Offset_t
+
+ +
+
+ +

◆ Storage

+ +
+
+
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + +
typedef Base::Storage cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::Storage
+
+ +
+
+ +

◆ StorageCoord

+ +
+
+
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + +
typedef TensorRef::StorageCoord cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::StorageCoord
+
+ +
+
+ +

◆ StrideVector

+ +
+
+
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + +
typedef TensorRef::StrideVector cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::StrideVector
+
+

Stride vector in storage coordinate space Least significant stride is = 1 and not stored

+ +
+
+ +

◆ TensorCoord

+ +
+
+
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + +
typedef TensorRef::TensorCoord cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::TensorCoord
+
+ +
+
+ +

◆ TensorRef

+ +
+
+
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + +
typedef Base cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::TensorRef
+
+ +
+
+ +

◆ TensorRef_t

+ +
+
+
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + +
typedef TensorRef cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::TensorRef_t
@@ -321,19 +517,19 @@ template<typename T>

Constructor & Destructor Documentation

- -

◆ TensorView() [1/2]

+ +

◆ TensorView() [1/4]

-template<typename T>
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + + + + + + + + + + + @@ -109,6 +124,9 @@ Files + + + @@ -124,9 +142,6 @@ Files - - - @@ -145,12 +160,21 @@ Files + + + + + + + + + @@ -170,7 +194,7 @@ Files diff --git a/docs/dir_c5917a9a879e9a6c73eaf5237444ab84.html b/docs/dir_c5917a9a879e9a6c73eaf5237444ab84.html index a66eb22fa..9011cf40c 100644 --- a/docs/dir_c5917a9a879e9a6c73eaf5237444ab84.html +++ b/docs/dir_c5917a9a879e9a6c73eaf5237444ab84.html @@ -79,12 +79,16 @@ $(function() {
- + @@ -348,27 +544,27 @@ template<typename T> - -

◆ TensorView() [2/2]

+ +

◆ TensorView() [2/4]

-template<typename T>
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
CUTLASS_HOST_DEVICE cutlass::TensorView< T >::TensorView CUTLASS_HOST_DEVICE cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::TensorView ( )
diff --git a/docs/debug_8h_source.html b/docs/debug_8h_source.html index 881b4e3f0..c404b4110 100644 --- a/docs/debug_8h_source.html +++ b/docs/debug_8h_source.html @@ -81,7 +81,7 @@ $(function() { diff --git a/docs/dgemm__traits_8h.html b/docs/dgemm__traits_8h.html index eebc2f364..ac6d33b0c 100644 --- a/docs/dgemm__traits_8h.html +++ b/docs/dgemm__traits_8h.html @@ -82,21 +82,21 @@ $(function() {

Defines structural traits of double-precision GEMM. More...

-
- + - + - + @@ -386,332 +582,34 @@ template<typename T> -

Member Function Documentation

- -

◆ at() [1/2]

+ +

◆ TensorView() [3/4]

-template<typename T>
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
CUTLASS_HOST_DEVICE cutlass::TensorView< T >::TensorView CUTLASS_HOST_DEVICE cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::TensorView (TensorRef_t const & Base const &  _ref,
Coord_t const & TensorCoord const &  _size 
- - -
- + - - - - -
CUTLASS_HOST_DEVICE T& cutlass::TensorView< T >::at CUTLASS_HOST_DEVICE cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::TensorView (Coord_t const & coord) const
-
-inline
-
- -
- - -

◆ at() [2/2]

- -
-
-
-template<typename T>
- - - - - -
- - - - - - - - -
CUTLASS_HOST_DEVICE T& cutlass::TensorView< T >::at (Offset_t idx) const
-
-inline
-
- -
-
- -

◆ const_ref()

- -
-
-
-template<typename T>
- - - - - -
- - - - - - - -
CUTLASS_HOST_DEVICE ConstTensorRef_t cutlass::TensorView< T >::const_ref ()
-
-inline
-
- -
-
- -

◆ contains()

- -
-
-
-template<typename T>
- - - - - -
- - - - - - - - -
CUTLASS_HOST_DEVICE bool cutlass::TensorView< T >::contains (Coord_t const & coord) const
-
-inline
-
- -
-
- -

◆ data()

- -
-
-
-template<typename T>
- - - - - -
- - - - - - - -
CUTLASS_HOST_DEVICE T* cutlass::TensorView< T >::data () const
-
-inline
-
- -
-
- -

◆ good()

- -
-
-
-template<typename T>
- - - - - -
- - - - - - - -
CUTLASS_HOST_DEVICE bool cutlass::TensorView< T >::good () const
-
-inline
-
- -
-
- -

◆ offset()

- -
-
-
-template<typename T>
- - - - - -
- - - - - - - - -
CUTLASS_HOST_DEVICE Offset_t cutlass::TensorView< T >::offset (Coord_t const & coord) const
-
-inline
-
- -
-
- -

◆ operator=()

- -
-
-
-template<typename T>
- - - - - -
- - - - - - - - -
CUTLASS_HOST_DEVICE TensorView& cutlass::TensorView< T >::operator= (TensorView< T > const & _tensor)
-
-inline
-
- -
-
- -

◆ operator[]()

- -
-
-
-template<typename T>
- - - - - -
- - - - - - - - -
T& cutlass::TensorView< T >::operator[] (Coord< Rank > const & coord) const
-
-inline
-
- -
-
- -

◆ ref() [1/2]

- -
-
-
-template<typename T>
- - - - - -
- - - - - - - -
CUTLASS_HOST_DEVICE TensorRef_t& cutlass::TensorView< T >::ref ()
-
-inline
-
- -
-
- -

◆ ref() [2/2]

- -
-
-
-template<typename T>
- - - - - -
- - - - - - - -
CUTLASS_HOST_DEVICE TensorRef_t const& cutlass::TensorView< T >::ref () const
-
-inline
-
- -
-
- -

◆ reset()

- -
-
-
-template<typename T>
- - -
- - - - - - + + - - + + + + + + + + @@ -728,19 +626,64 @@ template<typename T> - -

◆ size() [1/2]

+ +

◆ TensorView() [4/4]

-template<typename T>
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
CUTLASS_HOST_DEVICE void cutlass::TensorView< T >::reset (TensorRef_t const & _ref = TensorRef_t(0), Storageptr,
Coord_t const & _size = Coord_t() StrideVector const & stride,
TensorCoord const & size 
+ + +
- + + + + + + + + + + + + + + + + + + + + + + +
CUTLASS_HOST_DEVICE Coord_t const& cutlass::TensorView< T >::size CUTLASS_HOST_DEVICE cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::TensorView (Storageptr,
StorageCoord const & stride,
TensorCoord const & size 
)
+
+inline
+
+ +
+ +

Member Function Documentation

+ +

◆ capacity()

+ +
+
+
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + +
+ + + @@ -755,19 +698,306 @@ template<typename T> - -

◆ size() [2/2]

+ +

◆ const_ref()

-template<typename T>
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
CUTLASS_HOST_DEVICE size_t cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::capacity ( ) const
+ + +
- + + + + + +
CUTLASS_HOST_DEVICE int cutlass::TensorView< T >::size CUTLASS_HOST_DEVICE ConstTensorRef cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::const_ref () const
+
+inline
+
+ +
+ + +

◆ contains()

+ +
+
+
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE bool cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::contains (TensorCoord const & coord) const
+
+inline
+
+ +
+
+ +

◆ operator+()

+ +
+
+
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE TensorView cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::operator+ (TensorCoord const & b) const
+
+inline
+
+ +
+
+ +

◆ operator+=()

+ +
+
+
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE TensorView& cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::operator+= (TensorCoord const & b)
+
+inline
+
+ +
+
+ +

◆ operator-()

+ +
+
+
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE TensorView cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::operator- (TensorCoord const & b) const
+
+inline
+
+ +
+
+ +

◆ operator-=()

+ +
+
+
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE TensorView& cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::operator-= (TensorCoord const & b)
+
+inline
+
+ +
+
+ +

◆ operator=()

+ +
+
+
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE TensorView& cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::operator= (TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ > const & _tensor)
+
+inline
+
+ +
+
+ +

◆ ref()

+ +
+
+
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + +
CUTLASS_HOST_DEVICE TensorRef cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ref () const
+
+inline
+
+ +
+
+ +

◆ reset()

+ +
+
+
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + + + + + + + + + + + + +
CUTLASS_HOST_DEVICE void cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::reset (Base const & _ref = Base(),
TensorCoord const & _size = TensorCoord() 
)
+
+inline
+
+ +
+
+ +

◆ size() [1/2]

+ +
+
+
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + + + + +
+ + + + + + + +
CUTLASS_HOST_DEVICE TensorCoord const& cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::size () const
+
+inline
+
+ +
+
+ +

◆ size() [2/2]

+ +
+
+
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
+ + +
+ + + @@ -783,82 +1013,27 @@ template<typename T> - -

◆ stride() [1/2]

+ +

◆ subview()

-template<typename T>
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
CUTLASS_HOST_DEVICE Index cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::size ( int  dim)
- - -
- + - - - -
CUTLASS_HOST_DEVICE Coord_t const& cutlass::TensorView< T >::stride CUTLASS_HOST_DEVICE TensorView cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::subview () const
-
-inline
-
- -
- - -

◆ stride() [2/2]

- -
-
-
-template<typename T>
- - - - - -
- - - - - - - - -
CUTLASS_HOST_DEVICE int const& cutlass::TensorView< T >::stride (int dim) const
-
-inline
-
- -
-
- -

◆ subview()

- -
-
-
-template<typename T>
- - -
- - - - - + - + @@ -877,19 +1052,19 @@ template<typename T>

Member Data Documentation

- -

◆ Rank

+ +

◆ Rank

-template<typename T>
+template<typename Storage_ , int Rank_ = 4, typename MapFunc_ = IdentityTensorMapFunc<Rank_>, int StorageRank_ = MapFunc_::kStorageRank, typename Index_ = int, typename LongIndex_ = long long>
CUTLASS_HOST_DEVICE TensorView<T> cutlass::TensorView< T >::subview (Coord_t const & TensorCoord const &  location,
Coord_t TensorCoord  size 
@@ -907,7 +1082,7 @@ template<typename T> diff --git a/docs/classcutlass_1_1TensorView.png b/docs/classcutlass_1_1TensorView.png index 40500e8a3..46861ac91 100644 Binary files a/docs/classcutlass_1_1TensorView.png and b/docs/classcutlass_1_1TensorView.png differ diff --git a/docs/classcutlass_1_1ZipTileIterator-members.html b/docs/classcutlass_1_1ZipTileIterator-members.html new file mode 100644 index 000000000..6de74a494 --- /dev/null +++ b/docs/classcutlass_1_1ZipTileIterator-members.html @@ -0,0 +1,125 @@ + + + + + + + +Cutlass: Member List + + + + + + + + + + +
+
+
- +
int const cutlass::TensorView< T >::Rank = TensorRef_t::Rankint const cutlass::TensorView< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::Rank = Base::kRank
+ + + + + +
+
Cutlass +
+
CUDA Templates for Linear Algebra Subroutines and Solvers
+
+ + + + + + + + + +
+
+ + +
+ +
+ + + +
+
+
cutlass::ZipTileIterator< First_, Second_ > Member List
+
+
+ +

This is the complete list of members for cutlass::ZipTileIterator< First_, Second_ >, including all inherited members.

+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
add_pointer_offset(Index offset)cutlass::ZipTileIterator< First_, Second_ >inline
decrement(int count=1)cutlass::ZipTileIterator< First_, Second_ >inline
First typedefcutlass::ZipTileIterator< First_, Second_ >
firstcutlass::ZipTileIterator< First_, Second_ >
Fragment typedefcutlass::ZipTileIterator< First_, Second_ >
increment(int count=1)cutlass::ZipTileIterator< First_, Second_ >inline
Index typedefcutlass::ZipTileIterator< First_, Second_ >
initialize_predicates(PredicateIterator predicate_it, Coord< 3 > const &bounds, Coord< 3 > const &block_offset=make_Coord(0, 0, 0))cutlass::ZipTileIterator< First_, Second_ >inline
initialize_predicates(PredicateIterator predicate_it, PredicateFunctor const &functor, Coord< 3 > const &block_offset)cutlass::ZipTileIterator< First_, Second_ >inline
load(Fragment &fragment) constcutlass::ZipTileIterator< First_, Second_ >inline
load(Fragment &fragment, Coord< 4 > const &offset) constcutlass::ZipTileIterator< First_, Second_ >inline
load(Fragment &fragment, PredicateIterator pred_it) constcutlass::ZipTileIterator< First_, Second_ >inline
load_post_increment(Fragment &fragment)cutlass::ZipTileIterator< First_, Second_ >inline
load_post_increment(Fragment &fragment, Coord< 4 > const &offset)cutlass::ZipTileIterator< First_, Second_ >inline
load_post_increment(Fragment &fragment, PredicateIterator pred_it)cutlass::ZipTileIterator< First_, Second_ >inline
operator++()cutlass::ZipTileIterator< First_, Second_ >inline
operator+=(int count)cutlass::ZipTileIterator< First_, Second_ >inline
operator+=(Coord< 3 > const &offset)cutlass::ZipTileIterator< First_, Second_ >inline
operator--()cutlass::ZipTileIterator< First_, Second_ >inline
operator-=(int count)cutlass::ZipTileIterator< First_, Second_ >inline
PredicateVector typedefcutlass::ZipTileIterator< First_, Second_ >
secondcutlass::ZipTileIterator< First_, Second_ >
Second typedefcutlass::ZipTileIterator< First_, Second_ >
store(Fragment const &fragment) constcutlass::ZipTileIterator< First_, Second_ >inline
store(Fragment const &fragment, Coord< 4 > const &offset) constcutlass::ZipTileIterator< First_, Second_ >inline
store(Fragment const &fragment, PredicateIterator pred_it) constcutlass::ZipTileIterator< First_, Second_ >inline
store_post_increment(Fragment const &fragment)cutlass::ZipTileIterator< First_, Second_ >inline
store_post_increment(Fragment const &fragment, Coord< 4 > const &offset)cutlass::ZipTileIterator< First_, Second_ >inline
store_post_increment(Fragment const &fragment, PredicateIterator pred_it)cutlass::ZipTileIterator< First_, Second_ >inline
TensorRef typedefcutlass::ZipTileIterator< First_, Second_ >
ZipTileIterator()cutlass::ZipTileIterator< First_, Second_ >inline
ZipTileIterator(Params const &_params, Coord< 3 > const &threadblock_offset=make_Coord(0, 0, 0))cutlass::ZipTileIterator< First_, Second_ >inline
ZipTileIterator(First const &_first, Second const &_second)cutlass::ZipTileIterator< First_, Second_ >inline
ZipTileIterator(TensorRef const &ref)cutlass::ZipTileIterator< First_, Second_ >inline
ZipTileIterator(Params const &_params, TensorRef const &ref)cutlass::ZipTileIterator< First_, Second_ >inline
+ + + + diff --git a/docs/classcutlass_1_1ZipTileIterator.html b/docs/classcutlass_1_1ZipTileIterator.html new file mode 100644 index 000000000..7cf7a392b --- /dev/null +++ b/docs/classcutlass_1_1ZipTileIterator.html @@ -0,0 +1,1290 @@ + + + + + + + +Cutlass: cutlass::ZipTileIterator< First_, Second_ > Class Template Reference + + + + + + + + + + +
+
+ + + + + + +
+
Cutlass +
+
CUDA Templates for Linear Algebra Subroutines and Solvers
+
+
+ + + + + + + + +
+
+ + +
+ +
+ + +
+
+ +
+
cutlass::ZipTileIterator< First_, Second_ > Class Template Reference
+
+
+ +

Constructs an iterator from a pair of iterators. +

+ +

#include <zip_tile_iterator.h>

+ + + + + +

+Classes

struct  Params
 Params object. More...
 
+ + + + + + + + + + + + + + + + + + + +

+Public Types

typedef First_ First
 First iterator type. More...
 
typedef Second_ Second
 Second iterator type. More...
 
typedef ZipFragment< typename First::Fragment, typename Second::Fragment > Fragment
 Fragment type. More...
 
typedef First::PredicateVector PredicateVector
 Predicate vector. More...
 
typedef First::Index Index
 Index type. More...
 
typedef ZipTensorRef< typename First::TensorRef, typename Second::TensorRef > TensorRef
 Tensor reference. More...
 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

CUTLASS_DEVICE ZipTileIterator ()
 Default constructor. More...
 
CUTLASS_DEVICE ZipTileIterator (Params const &_params, Coord< 3 > const &threadblock_offset=make_Coord(0, 0, 0))
 Constructs a zip iterator from params. More...
 
CUTLASS_DEVICE ZipTileIterator (First const &_first, Second const &_second)
 Constructs a zip iterator from iterator instances. More...
 
CUTLASS_DEVICE ZipTileIterator (TensorRef const &ref)
 Constructs a zip iterator from iterator instances. More...
 
CUTLASS_DEVICE ZipTileIterator (Params const &_params, TensorRef const &ref)
 Constructs a zip iterator from iterator instances. More...
 
template<typename PredicateIterator >
CUTLASS_HOST_DEVICE void initialize_predicates (PredicateIterator predicate_it, Coord< 3 > const &bounds, Coord< 3 > const &block_offset=make_Coord(0, 0, 0))
 Initializes a predicate vector using a RegularTilePredicateFunctor. More...
 
template<typename PredicateIterator , typename PredicateFunctor >
CUTLASS_HOST_DEVICE void initialize_predicates (PredicateIterator predicate_it, PredicateFunctor const &functor, Coord< 3 > const &block_offset)
 Initializes a predicate vector using an arbitrary predicate functor. More...
 
template<typename Fragment >
CUTLASS_DEVICE void load_post_increment (Fragment &fragment)
 Loads a fragment and increments without predicates. More...
 
template<typename Fragment >
CUTLASS_DEVICE void load_post_increment (Fragment &fragment, Coord< 4 > const &offset)
 Loads a fragment and increments without predicates. More...
 
template<typename Fragment >
CUTLASS_DEVICE void load (Fragment &fragment) const
 Loads a fragment without predicates. More...
 
template<typename Fragment >
CUTLASS_DEVICE void load (Fragment &fragment, Coord< 4 > const &offset) const
 Loads a fragment without predicates. More...
 
template<typename Fragment >
CUTLASS_DEVICE void store_post_increment (Fragment const &fragment)
 Stores a fragment and increments without predicates. More...
 
template<typename Fragment >
CUTLASS_DEVICE void store_post_increment (Fragment const &fragment, Coord< 4 > const &offset)
 Stores a fragment and increments without predicates. More...
 
template<typename Fragment >
CUTLASS_DEVICE void store (Fragment const &fragment) const
 Stores a fragment without predicates. More...
 
template<typename Fragment >
CUTLASS_DEVICE void store (Fragment const &fragment, Coord< 4 > const &offset) const
 Stores a fragment without predicates. More...
 
template<typename Fragment , typename PredicateIterator >
CUTLASS_DEVICE void load_post_increment (Fragment &fragment, PredicateIterator pred_it)
 Loads a fragment and increments, using predicates. More...
 
template<typename Fragment , typename PredicateIterator >
CUTLASS_DEVICE void load (Fragment &fragment, PredicateIterator pred_it) const
 Loads a fragment with predicates. More...
 
template<typename Fragment , typename PredicateIterator >
CUTLASS_DEVICE void store_post_increment (Fragment const &fragment, PredicateIterator pred_it)
 Loads a fragment and increments, using predicates. More...
 
template<typename Fragment , typename PredicateIterator >
CUTLASS_DEVICE void store (Fragment const &fragment, PredicateIterator pred_it) const
 Loads a fragment with predicates. More...
 
CUTLASS_DEVICE ZipTileIteratorincrement (int count=1)
 Increments store iterator to next tile. More...
 
CUTLASS_DEVICE ZipTileIteratoroperator++ ()
 Increments to next tile. More...
 
CUTLASS_DEVICE ZipTileIteratoroperator+= (int count)
 
CUTLASS_DEVICE ZipTileIteratoroperator+= (Coord< 3 > const &offset)
 Adds a vector offset to the underlying iterators. More...
 
CUTLASS_DEVICE ZipTileIteratordecrement (int count=1)
 Increments store iterator to previous tile. More...
 
CUTLASS_DEVICE ZipTileIteratoroperator-- ()
 Increments to subsequent tile. More...
 
CUTLASS_DEVICE ZipTileIteratoroperator-= (int count)
 Decrements to previous tile. More...
 
CUTLASS_DEVICE void add_pointer_offset (Index offset)
 Adds an offset to both iterators. More...
 
+ + + + + + + +

+Public Attributes

First first
 First iterator. More...
 
Second second
 Second iterator. More...
 
+

Member Typedef Documentation

+ +

◆ First

+ +
+
+
+template<typename First_ , typename Second_ >
+ + + + +
typedef First_ cutlass::ZipTileIterator< First_, Second_ >::First
+
+ +
+
+ +

◆ Fragment

+ +
+
+
+template<typename First_ , typename Second_ >
+ + + + +
typedef ZipFragment<typename First::Fragment, typename Second::Fragment> cutlass::ZipTileIterator< First_, Second_ >::Fragment
+
+ +
+
+ +

◆ Index

+ +
+
+
+template<typename First_ , typename Second_ >
+ + + + +
typedef First::Index cutlass::ZipTileIterator< First_, Second_ >::Index
+
+ +
+
+ +

◆ PredicateVector

+ +
+
+
+template<typename First_ , typename Second_ >
+ + + + +
typedef First::PredicateVector cutlass::ZipTileIterator< First_, Second_ >::PredicateVector
+
+ +
+
+ +

◆ Second

+ +
+
+
+template<typename First_ , typename Second_ >
+ + + + +
typedef Second_ cutlass::ZipTileIterator< First_, Second_ >::Second
+
+ +
+
+ +

◆ TensorRef

+ +
+
+
+template<typename First_ , typename Second_ >
+ + + + +
typedef ZipTensorRef< typename First::TensorRef, typename Second::TensorRef> cutlass::ZipTileIterator< First_, Second_ >::TensorRef
+
+ +
+
+

Constructor & Destructor Documentation

+ +

◆ ZipTileIterator() [1/5]

+ +
+
+
+template<typename First_ , typename Second_ >
+ + + + + +
+ + + + + + + +
CUTLASS_DEVICE cutlass::ZipTileIterator< First_, Second_ >::ZipTileIterator ()
+
+inline
+
+ +
+
+ +

◆ ZipTileIterator() [2/5]

+ +
+
+
+template<typename First_ , typename Second_ >
+ + + + + +
+ + + + + + + + + + + + + + + + + + +
CUTLASS_DEVICE cutlass::ZipTileIterator< First_, Second_ >::ZipTileIterator (Params const & _params,
Coord< 3 > const & threadblock_offset = make_Coord(0, 0, 0) 
)
+
+inline
+
+ +
+
+ +

◆ ZipTileIterator() [3/5]

+ +
+
+
+template<typename First_ , typename Second_ >
+ + + + + +
+ + + + + + + + + + + + + + + + + + +
CUTLASS_DEVICE cutlass::ZipTileIterator< First_, Second_ >::ZipTileIterator (First const & _first,
Second const & _second 
)
+
+inline
+
+ +
+
+ +

◆ ZipTileIterator() [4/5]

+ +
+
+
+template<typename First_ , typename Second_ >
+ + + + + +
+ + + + + + + + +
CUTLASS_DEVICE cutlass::ZipTileIterator< First_, Second_ >::ZipTileIterator (TensorRef const & ref)
+
+inline
+
+ +
+
+ +

◆ ZipTileIterator() [5/5]

+ +
+
+
+template<typename First_ , typename Second_ >
+ + + + + +
+ + + + + + + + + + + + + + + + + + +
CUTLASS_DEVICE cutlass::ZipTileIterator< First_, Second_ >::ZipTileIterator (Params const & _params,
TensorRef const & ref 
)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ add_pointer_offset()

+ +
+
+
+template<typename First_ , typename Second_ >
+ + + + + +
+ + + + + + + + +
CUTLASS_DEVICE void cutlass::ZipTileIterator< First_, Second_ >::add_pointer_offset (Index offset)
+
+inline
+
+ +
+
+ +

◆ decrement()

+ +
+
+
+template<typename First_ , typename Second_ >
+ + + + + +
+ + + + + + + + +
CUTLASS_DEVICE ZipTileIterator& cutlass::ZipTileIterator< First_, Second_ >::decrement (int count = 1)
+
+inline
+
+ +
+
+ +

◆ increment()

+ +
+
+
+template<typename First_ , typename Second_ >
+ + + + + +
+ + + + + + + + +
CUTLASS_DEVICE ZipTileIterator& cutlass::ZipTileIterator< First_, Second_ >::increment (int count = 1)
+
+inline
+
+ +
+
+ +

◆ initialize_predicates() [1/2]

+ +
+
+
+template<typename First_ , typename Second_ >
+
+template<typename PredicateIterator >
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + + + + +
CUTLASS_HOST_DEVICE void cutlass::ZipTileIterator< First_, Second_ >::initialize_predicates (PredicateIterator predicate_it,
Coord< 3 > const & bounds,
Coord< 3 > const & block_offset = make_Coord(0,                                                                                           0,                                                                                           0) 
)
+
+inline
+
+ +
+
+ +

◆ initialize_predicates() [2/2]

+ +
+
+
+template<typename First_ , typename Second_ >
+
+template<typename PredicateIterator , typename PredicateFunctor >
+ + + + + +
+ + + + + + + + + + + + + + + + + + + + + + + + +
CUTLASS_HOST_DEVICE void cutlass::ZipTileIterator< First_, Second_ >::initialize_predicates (PredicateIterator predicate_it,
PredicateFunctor const & functor,
Coord< 3 > const & block_offset 
)
+
+inline
+
+ +
+
+ +

◆ load() [1/3]

+ +
+
+
+template<typename First_ , typename Second_ >
+
+template<typename Fragment >
+ + + + + +
+ + + + + + + + +
CUTLASS_DEVICE void cutlass::ZipTileIterator< First_, Second_ >::load (Fragmentfragment) const
+
+inline
+
+ +
+
+ +

◆ load() [2/3]

+ +
+
+
+template<typename First_ , typename Second_ >
+
+template<typename Fragment >
+ + + + + +
+ + + + + + + + + + + + + + + + + + +
CUTLASS_DEVICE void cutlass::ZipTileIterator< First_, Second_ >::load (Fragmentfragment,
Coord< 4 > const & offset 
) const
+
+inline
+
+ +
+
+ +

◆ load() [3/3]

+ +
+
+
+template<typename First_ , typename Second_ >
+
+template<typename Fragment , typename PredicateIterator >
+ + + + + +
+ + + + + + + + + + + + + + + + + + +
CUTLASS_DEVICE void cutlass::ZipTileIterator< First_, Second_ >::load (Fragmentfragment,
PredicateIterator pred_it 
) const
+
+inline
+
+ +
+
+ +

◆ load_post_increment() [1/3]

+ +
+
+
+template<typename First_ , typename Second_ >
+
+template<typename Fragment >
+ + + + + +
+ + + + + + + + +
CUTLASS_DEVICE void cutlass::ZipTileIterator< First_, Second_ >::load_post_increment (Fragmentfragment)
+
+inline
+
+ +
+
+ +

◆ load_post_increment() [2/3]

+ +
+
+
+template<typename First_ , typename Second_ >
+
+template<typename Fragment >
+ + + + + +
+ + + + + + + + + + + + + + + + + + +
CUTLASS_DEVICE void cutlass::ZipTileIterator< First_, Second_ >::load_post_increment (Fragmentfragment,
Coord< 4 > const & offset 
)
+
+inline
+
+ +
+
+ +

◆ load_post_increment() [3/3]

+ +
+
+
+template<typename First_ , typename Second_ >
+
+template<typename Fragment , typename PredicateIterator >
+ + + + + +
+ + + + + + + + + + + + + + + + + + +
CUTLASS_DEVICE void cutlass::ZipTileIterator< First_, Second_ >::load_post_increment (Fragmentfragment,
PredicateIterator pred_it 
)
+
+inline
+
+ +
+
+ +

◆ operator++()

+ +
+
+
+template<typename First_ , typename Second_ >
+ + + + + +
+ + + + + + + +
CUTLASS_DEVICE ZipTileIterator& cutlass::ZipTileIterator< First_, Second_ >::operator++ ()
+
+inline
+
+ +
+
+ +

◆ operator+=() [1/2]

+ +
+
+
+template<typename First_ , typename Second_ >
+ + + + + +
+ + + + + + + + +
CUTLASS_DEVICE ZipTileIterator& cutlass::ZipTileIterator< First_, Second_ >::operator+= (int count)
+
+inline
+
+ +
+
+ +

◆ operator+=() [2/2]

+ +
+
+
+template<typename First_ , typename Second_ >
+ + + + + +
+ + + + + + + + +
CUTLASS_DEVICE ZipTileIterator& cutlass::ZipTileIterator< First_, Second_ >::operator+= (Coord< 3 > const & offset)
+
+inline
+
+ +
+
+ +

◆ operator--()

+ +
+
+
+template<typename First_ , typename Second_ >
+ + + + + +
+ + + + + + + +
CUTLASS_DEVICE ZipTileIterator& cutlass::ZipTileIterator< First_, Second_ >::operator-- ()
+
+inline
+
+ +
+
+ +

◆ operator-=()

+ +
+
+
+template<typename First_ , typename Second_ >
+ + + + + +
+ + + + + + + + +
CUTLASS_DEVICE ZipTileIterator& cutlass::ZipTileIterator< First_, Second_ >::operator-= (int count)
+
+inline
+
+ +
+
+ +

◆ store() [1/3]

+ +
+
+
+template<typename First_ , typename Second_ >
+
+template<typename Fragment >
+ + + + + +
+ + + + + + + + +
CUTLASS_DEVICE void cutlass::ZipTileIterator< First_, Second_ >::store (Fragment const & fragment) const
+
+inline
+
+ +
+
+ +

◆ store() [2/3]

+ +
+
+
+template<typename First_ , typename Second_ >
+
+template<typename Fragment >
+ + + + + +
+ + + + + + + + + + + + + + + + + + +
CUTLASS_DEVICE void cutlass::ZipTileIterator< First_, Second_ >::store (Fragment const & fragment,
Coord< 4 > const & offset 
) const
+
+inline
+
+ +
+
+ +

◆ store() [3/3]

+ +
+
+
+template<typename First_ , typename Second_ >
+
+template<typename Fragment , typename PredicateIterator >
+ + + + + +
+ + + + + + + + + + + + + + + + + + +
CUTLASS_DEVICE void cutlass::ZipTileIterator< First_, Second_ >::store (Fragment const & fragment,
PredicateIterator pred_it 
) const
+
+inline
+
+ +
+
+ +

◆ store_post_increment() [1/3]

+ +
+
+
+template<typename First_ , typename Second_ >
+
+template<typename Fragment >
+ + + + + +
+ + + + + + + + +
CUTLASS_DEVICE void cutlass::ZipTileIterator< First_, Second_ >::store_post_increment (Fragment const & fragment)
+
+inline
+
+ +
+
+ +

◆ store_post_increment() [2/3]

+ +
+
+
+template<typename First_ , typename Second_ >
+
+template<typename Fragment >
+ + + + + +
+ + + + + + + + + + + + + + + + + + +
CUTLASS_DEVICE void cutlass::ZipTileIterator< First_, Second_ >::store_post_increment (Fragment const & fragment,
Coord< 4 > const & offset 
)
+
+inline
+
+ +
+
+ +

◆ store_post_increment() [3/3]

+ +
+
+
+template<typename First_ , typename Second_ >
+
+template<typename Fragment , typename PredicateIterator >
+ + + + + +
+ + + + + + + + + + + + + + + + + + +
CUTLASS_DEVICE void cutlass::ZipTileIterator< First_, Second_ >::store_post_increment (Fragment const & fragment,
PredicateIterator pred_it 
)
+
+inline
+
+ +
+
+

Member Data Documentation

+ +

◆ first

+ +
+
+
+template<typename First_ , typename Second_ >
+ + + + +
First cutlass::ZipTileIterator< First_, Second_ >::first
+
+ +
+
+ +

◆ second

+ +
+
+
+template<typename First_ , typename Second_ >
+ + + + +
Second cutlass::ZipTileIterator< First_, Second_ >::second
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/classcutlass_1_1detail_1_1ScalarOrPointer-members.html b/docs/classcutlass_1_1detail_1_1ScalarOrPointer-members.html new file mode 100644 index 000000000..8da714665 --- /dev/null +++ b/docs/classcutlass_1_1detail_1_1ScalarOrPointer-members.html @@ -0,0 +1,101 @@ + + + + + + + +Cutlass: Member List + + + + + + + + + + +
+
+ + + + + + +
+
Cutlass +
+
CUDA Templates for Linear Algebra Subroutines and Solvers
+
+
+ + + + + + + + +
+
+ + +
+ +
+ + +
+
+
+
cutlass::detail::ScalarOrPointer< Scalar_ > Member List
+
+ + + + + diff --git a/docs/classcutlass_1_1detail_1_1ScalarOrPointer.html b/docs/classcutlass_1_1detail_1_1ScalarOrPointer.html new file mode 100644 index 000000000..6a28c38f8 --- /dev/null +++ b/docs/classcutlass_1_1detail_1_1ScalarOrPointer.html @@ -0,0 +1,434 @@ + + + + + + + +Cutlass: cutlass::detail::ScalarOrPointer< Scalar_ > Class Template Reference + + + + + + + + + + +
+
+ + + + + + +
+
Cutlass +
+
CUDA Templates for Linear Algebra Subroutines and Solvers
+
+
+ + + + + + + + +
+
+ + +
+ +
+ + +
+
+ +
+
cutlass::detail::ScalarOrPointer< Scalar_ > Class Template Reference
+
+
+ +

#include <scalar_or_pointer.h>

+ + + + + +

+Public Types

typedef Scalar_ Scalar
 Underlying scalar type. More...
 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

CUTLASS_HOST_DEVICE ScalarOrPointer ()
 Default ctor. More...
 
CUTLASS_HOST_DEVICE ScalarOrPointer (Scalar const &val)
 Object behaves as a scalar. More...
 
CUTLASS_HOST_DEVICE ScalarOrPointer (Scalar const *ptr_)
 Object behaves as a scalar. More...
 
CUTLASS_HOST_DEVICE bool is_pointer () const
 Returns true if is pointer. More...
 
CUTLASS_HOST_DEVICE Scalar const * get_ptr () const
 Gets the pointer value. More...
 
CUTLASS_HOST_DEVICE Scalar get_scalar () const
 Gets the pointer value. More...
 
CUTLASS_HOST_DEVICE ScalarOrPointeroperator= (Scalar const &scalar_)
 Assigns to a scalar and sets pointer to nullptr. More...
 
CUTLASS_HOST_DEVICE ScalarOrPointeroperator= (Scalar const *ptr_)
 Assigns to a pointer value. More...
 
CUTLASS_HOST_DEVICE Scalar get () const
 Access the element. More...
 
CUTLASS_HOST_DEVICE operator Scalar () const
 Accesses the element. More...
 
+

Detailed Description

+

template<typename Scalar_>
+class cutlass::detail::ScalarOrPointer< Scalar_ >

+ +

Helper class defines an object which operates as either a scalar or a pointer. If the pointer is non-null, it is dereferenced when the object is accessed.

+

Member Typedef Documentation

+ +

◆ Scalar

+ +
+
+
+template<typename Scalar_>
+ + + + +
typedef Scalar_ cutlass::detail::ScalarOrPointer< Scalar_ >::Scalar
+
+ +
+
+

Constructor & Destructor Documentation

+ +

◆ ScalarOrPointer() [1/3]

+ +
+
+
+template<typename Scalar_>
+ + + + + +
+ + + + + + + +
CUTLASS_HOST_DEVICE cutlass::detail::ScalarOrPointer< Scalar_ >::ScalarOrPointer ()
+
+inline
+
+ +
+
+ +

◆ ScalarOrPointer() [2/3]

+ +
+
+
+template<typename Scalar_>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE cutlass::detail::ScalarOrPointer< Scalar_ >::ScalarOrPointer (Scalar const & val)
+
+inline
+
+ +
+
+ +

◆ ScalarOrPointer() [3/3]

+ +
+
+
+template<typename Scalar_>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE cutlass::detail::ScalarOrPointer< Scalar_ >::ScalarOrPointer (Scalar const * ptr_)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ get()

+ +
+
+
+template<typename Scalar_>
+ + + + + +
+ + + + + + + +
CUTLASS_HOST_DEVICE Scalar cutlass::detail::ScalarOrPointer< Scalar_ >::get () const
+
+inline
+
+ +
+
+ +

◆ get_ptr()

+ +
+
+
+template<typename Scalar_>
+ + + + + +
+ + + + + + + +
CUTLASS_HOST_DEVICE Scalar const* cutlass::detail::ScalarOrPointer< Scalar_ >::get_ptr () const
+
+inline
+
+ +
+
+ +

◆ get_scalar()

+ +
+
+
+template<typename Scalar_>
+ + + + + +
+ + + + + + + +
CUTLASS_HOST_DEVICE Scalar cutlass::detail::ScalarOrPointer< Scalar_ >::get_scalar () const
+
+inline
+
+ +
+
+ +

◆ is_pointer()

+ +
+
+
+template<typename Scalar_>
+ + + + + +
+ + + + + + + +
CUTLASS_HOST_DEVICE bool cutlass::detail::ScalarOrPointer< Scalar_ >::is_pointer () const
+
+inline
+
+ +
+
+ +

◆ operator Scalar()

+ +
+
+
+template<typename Scalar_>
+ + + + + +
+ + + + + + + +
CUTLASS_HOST_DEVICE cutlass::detail::ScalarOrPointer< Scalar_ >::operator Scalar () const
+
+inline
+
+ +
+
+ +

◆ operator=() [1/2]

+ +
+
+
+template<typename Scalar_>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE ScalarOrPointer& cutlass::detail::ScalarOrPointer< Scalar_ >::operator= (Scalar const & scalar_)
+
+inline
+
+ +
+
+ +

◆ operator=() [2/2]

+ +
+
+
+template<typename Scalar_>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE ScalarOrPointer& cutlass::detail::ScalarOrPointer< Scalar_ >::operator= (Scalar const * ptr_)
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/classcutlass_1_1gemm_1_1LinearScalingDevicePtr_1_1Params-members.html b/docs/classcutlass_1_1gemm_1_1LinearScalingDevicePtr_1_1Params-members.html new file mode 100644 index 000000000..323b1406c --- /dev/null +++ b/docs/classcutlass_1_1gemm_1_1LinearScalingDevicePtr_1_1Params-members.html @@ -0,0 +1,98 @@ + + + + + + + +Cutlass: Member List + + + + + + + + + + +
+
+ + + + + + +
+
Cutlass +
+
CUDA Templates for Linear Algebra Subroutines and Solvers
+
+
+ + + + + + + + +
+
+ + +
+ +
+ + +
+
+
+
cutlass::gemm::LinearScalingDevicePtr< Scalar_, FragmentMultiplyAdd_ >::Params Member List
+
+ + + + + diff --git a/docs/classcutlass_1_1gemm_1_1LinearScalingDevicePtr_1_1Params.html b/docs/classcutlass_1_1gemm_1_1LinearScalingDevicePtr_1_1Params.html new file mode 100644 index 000000000..5fc5d05e3 --- /dev/null +++ b/docs/classcutlass_1_1gemm_1_1LinearScalingDevicePtr_1_1Params.html @@ -0,0 +1,389 @@ + + + + + + + +Cutlass: cutlass::gemm::LinearScalingDevicePtr< Scalar_, FragmentMultiplyAdd_ >::Params Class Reference + + + + + + + + + + +
+
+ + + + + + +
+
Cutlass +
+
CUDA Templates for Linear Algebra Subroutines and Solvers
+
+
+ + + + + + + + +
+
+ + +
+ +
+ + +
+
+ +
+
cutlass::gemm::LinearScalingDevicePtr< Scalar_, FragmentMultiplyAdd_ >::Params Class Reference
+
+
+ +

The parameters. +

+ +

#include <linear_scaling_device_ptr.h>

+ + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

CUTLASS_HOST_DEVICE Params ()
 
CUTLASS_HOST_DEVICE Params (Scalar alpha, Scalar beta)
 
CUTLASS_HOST_DEVICE Params (Scalar const *alpha_ptr, Scalar const *beta_ptr)
 
CUTLASS_HOST_DEVICE int initialize (Scalar alpha, Scalar beta)
 Initialize the parameters. More...
 
CUTLASS_HOST_DEVICE int initialize (Scalar const *alpha, Scalar const *beta)
 Initialize the parameters. More...
 
template<typename GemmDesc_ >
CUTLASS_HOST_DEVICE int initialize (GemmDesc_ const &desc)
 Initialize the parameters. More...
 
CUTLASS_HOST_DEVICE Scalar alpha () const
 Gets the alpha scalar. More...
 
CUTLASS_HOST_DEVICE Scalar beta () const
 Gets the beta scalar. More...
 
+

Constructor & Destructor Documentation

+ +

◆ Params() [1/3]

+ +
+
+
+template<typename Scalar_ , typename FragmentMultiplyAdd_ = FragmentMultiplyAdd<Scalar_, Scalar_>>
+ + + + + +
+ + + + + + + +
CUTLASS_HOST_DEVICE cutlass::gemm::LinearScalingDevicePtr< Scalar_, FragmentMultiplyAdd_ >::Params::Params ()
+
+inline
+
+ +
+
+ +

◆ Params() [2/3]

+ +
+
+
+template<typename Scalar_ , typename FragmentMultiplyAdd_ = FragmentMultiplyAdd<Scalar_, Scalar_>>
+ + + + + +
+ + + + + + + + + + + + + + + + + + +
CUTLASS_HOST_DEVICE cutlass::gemm::LinearScalingDevicePtr< Scalar_, FragmentMultiplyAdd_ >::Params::Params (Scalar alpha,
Scalar beta 
)
+
+inline
+
+ +
+
+ +

◆ Params() [3/3]

+ +
+
+
+template<typename Scalar_ , typename FragmentMultiplyAdd_ = FragmentMultiplyAdd<Scalar_, Scalar_>>
+ + + + + +
+ + + + + + + + + + + + + + + + + + +
CUTLASS_HOST_DEVICE cutlass::gemm::LinearScalingDevicePtr< Scalar_, FragmentMultiplyAdd_ >::Params::Params (Scalar const * alpha_ptr,
Scalar const * beta_ptr 
)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ alpha()

+ +
+
+
+template<typename Scalar_ , typename FragmentMultiplyAdd_ = FragmentMultiplyAdd<Scalar_, Scalar_>>
+ + + + + +
+ + + + + + + +
CUTLASS_HOST_DEVICE Scalar cutlass::gemm::LinearScalingDevicePtr< Scalar_, FragmentMultiplyAdd_ >::Params::alpha () const
+
+inline
+
+ +
+
+ +

◆ beta()

+ +
+
+
+template<typename Scalar_ , typename FragmentMultiplyAdd_ = FragmentMultiplyAdd<Scalar_, Scalar_>>
+ + + + + +
+ + + + + + + +
CUTLASS_HOST_DEVICE Scalar cutlass::gemm::LinearScalingDevicePtr< Scalar_, FragmentMultiplyAdd_ >::Params::beta () const
+
+inline
+
+ +
+
+ +

◆ initialize() [1/3]

+ +
+
+
+template<typename Scalar_ , typename FragmentMultiplyAdd_ = FragmentMultiplyAdd<Scalar_, Scalar_>>
+ + + + + +
+ + + + + + + + + + + + + + + + + + +
CUTLASS_HOST_DEVICE int cutlass::gemm::LinearScalingDevicePtr< Scalar_, FragmentMultiplyAdd_ >::Params::initialize (Scalar alpha,
Scalar beta 
)
+
+inline
+
+ +
+
+ +

◆ initialize() [2/3]

+ +
+
+
+template<typename Scalar_ , typename FragmentMultiplyAdd_ = FragmentMultiplyAdd<Scalar_, Scalar_>>
+ + + + + +
+ + + + + + + + + + + + + + + + + + +
CUTLASS_HOST_DEVICE int cutlass::gemm::LinearScalingDevicePtr< Scalar_, FragmentMultiplyAdd_ >::Params::initialize (Scalar const * alpha,
Scalar const * beta 
)
+
+inline
+
+ +
+
+ +

◆ initialize() [3/3]

+ +
+
+
+template<typename Scalar_ , typename FragmentMultiplyAdd_ = FragmentMultiplyAdd<Scalar_, Scalar_>>
+
+template<typename GemmDesc_ >
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE int cutlass::gemm::LinearScalingDevicePtr< Scalar_, FragmentMultiplyAdd_ >::Params::initialize (GemmDesc_ const & desc)
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/classcutlass_1_1platform_1_1complex-members.html b/docs/classcutlass_1_1platform_1_1complex-members.html new file mode 100644 index 000000000..3e19742e0 --- /dev/null +++ b/docs/classcutlass_1_1platform_1_1complex-members.html @@ -0,0 +1,100 @@ + + + + + + + +Cutlass: Member List + + + + + + + + + + +
+
+ + + + + + +
+
Cutlass +
+
CUDA Templates for Linear Algebra Subroutines and Solvers
+
+
+ + + + + + + + +
+
+ + +
+ +
+ + +
+
+
+
cutlass::platform::complex< T > Member List
+
+ + + + + diff --git a/docs/classcutlass_1_1platform_1_1complex.html b/docs/classcutlass_1_1platform_1_1complex.html new file mode 100644 index 000000000..672fef7e9 --- /dev/null +++ b/docs/classcutlass_1_1platform_1_1complex.html @@ -0,0 +1,413 @@ + + + + + + + +Cutlass: cutlass::platform::complex< T > Class Template Reference + + + + + + + + + + +
+
+ + + + + + +
+
Cutlass +
+
CUDA Templates for Linear Algebra Subroutines and Solvers
+
+
+ + + + + + + + +
+
+ + +
+ +
+ + +
+
+ +
+
cutlass::platform::complex< T > Class Template Reference
+
+
+ +

#include <complex.h>

+ + + + + +

+Public Types

typedef T value_type
 Type alias for scalar type. More...
 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Public Member Functions

CUTLASS_HOST_DEVICE complex (T r=T(0), T i=T(0))
 Constructor. More...
 
CUTLASS_HOST_DEVICE complex (cuFloatComplex const &z)
 Conversion from cuFloatComplex. More...
 
CUTLASS_HOST_DEVICE complex (cuDoubleComplex const &z)
 Conversion from cuDoubleComplex. More...
 
CUTLASS_HOST_DEVICE T const & real () const
 Accesses the real part of the complex number. More...
 
CUTLASS_HOST_DEVICE T & real ()
 Accesses the real part of the complex number. More...
 
CUTLASS_HOST_DEVICE T const & imag () const
 Accesses the imaginary part of the complex number. More...
 
CUTLASS_HOST_DEVICE T & imag ()
 Accesses the imaginary part of the complex number. More...
 
CUTLASS_HOST_DEVICE operator cuFloatComplex () const
 Converts to cuFloatComplex. More...
 
CUTLASS_HOST_DEVICE operator cuDoubleComplex () const
 Converts to cuDoubleComplex. More...
 
+

Detailed Description

+

template<typename T>
+class cutlass::platform::complex< T >

+ +

Class for representing and manipulating complex numbers with conversions from built-in CUDA complex types.

+

Member Typedef Documentation

+ +

◆ value_type

+ +
+
+
+template<typename T>
+ + + + +
typedef T cutlass::platform::complex< T >::value_type
+
+ +
+
+

Constructor & Destructor Documentation

+ +

◆ complex() [1/3]

+ +
+
+
+template<typename T>
+ + + + + +
+ + + + + + + + + + + + + + + + + + +
CUTLASS_HOST_DEVICE cutlass::platform::complex< T >::complex (r = T(0),
i = T(0) 
)
+
+inline
+
+ +
+
+ +

◆ complex() [2/3]

+ +
+
+
+template<typename T>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE cutlass::platform::complex< T >::complex (cuFloatComplex const & z)
+
+inline
+
+ +
+
+ +

◆ complex() [3/3]

+ +
+
+
+template<typename T>
+ + + + + +
+ + + + + + + + +
CUTLASS_HOST_DEVICE cutlass::platform::complex< T >::complex (cuDoubleComplex const & z)
+
+inline
+
+ +
+
+

Member Function Documentation

+ +

◆ imag() [1/2]

+ +
+
+
+template<typename T>
+ + + + + +
+ + + + + + + +
CUTLASS_HOST_DEVICE T const& cutlass::platform::complex< T >::imag () const
+
+inline
+
+ +
+
+ +

◆ imag() [2/2]

+ +
+
+
+template<typename T>
+ + + + + +
+ + + + + + + +
CUTLASS_HOST_DEVICE T& cutlass::platform::complex< T >::imag ()
+
+inline
+
+ +
+
+ +

◆ operator cuDoubleComplex()

+ +
+
+
+template<typename T>
+ + + + + +
+ + + + + + + +
CUTLASS_HOST_DEVICE cutlass::platform::complex< T >::operator cuDoubleComplex () const
+
+inline
+
+ +
+
+ +

◆ operator cuFloatComplex()

+ +
+
+
+template<typename T>
+ + + + + +
+ + + + + + + +
CUTLASS_HOST_DEVICE cutlass::platform::complex< T >::operator cuFloatComplex () const
+
+inline
+
+ +
+
+ +

◆ real() [1/2]

+ +
+
+
+template<typename T>
+ + + + + +
+ + + + + + + +
CUTLASS_HOST_DEVICE T const& cutlass::platform::complex< T >::real () const
+
+inline
+
+ +
+
+ +

◆ real() [2/2]

+ +
+
+
+template<typename T>
+ + + + + +
+ + + + + + + +
CUTLASS_HOST_DEVICE T& cutlass::platform::complex< T >::real ()
+
+inline
+
+ +
+
+
The documentation for this class was generated from the following file: +
+ + + + diff --git a/docs/classcutlass_1_1platform_1_1unique__ptr-members.html b/docs/classcutlass_1_1platform_1_1unique__ptr-members.html index 696f47884..1242de683 100644 --- a/docs/classcutlass_1_1platform_1_1unique__ptr-members.html +++ b/docs/classcutlass_1_1platform_1_1unique__ptr-members.html @@ -98,7 +98,7 @@ $(function() {
diff --git a/docs/classcutlass_1_1platform_1_1unique__ptr.html b/docs/classcutlass_1_1platform_1_1unique__ptr.html index cf455f2e5..625e790b8 100644 --- a/docs/classcutlass_1_1platform_1_1unique__ptr.html +++ b/docs/classcutlass_1_1platform_1_1unique__ptr.html @@ -546,7 +546,7 @@ template<class T, class Deleter = default_delete<T>>
diff --git a/docs/classes.html b/docs/classes.html index 9896653f6..6a517312c 100644 --- a/docs/classes.html +++ b/docs/classes.html @@ -72,100 +72,116 @@ $(function() {
Class Index
-
a | b | c | d | e | f | g | h | i | l | m | n | p | r | s | t | u | v | w
+
a | b | c | d | e | f | g | h | i | k | l | m | n | p | r | s | t | u | v | w | z
- - - - - - - - - - - - - - - - + + + + + + + + + - - - + + + + + + + + + + - - + + + + - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + - - - + + + - - - - + + + + + + + - - - - - - + + + + + + + - - - - - - - + + + + + + + + + + + + +
  a  
-
FragmentMultiplyAdd (cutlass::gemm)   IgemmEpilogueScalar (cutlass::gemm)   Load< Scalar_, Lanes_, Memory_, true, 8 > (cutlass)   GlobalLoadStreamBase::SharedStorage (cutlass::gemm)   
FragmentMultiplyAdd< half > (cutlass::gemm)   IgemmEpilogueScalar< int > (cutlass::gemm)   log2_down (cutlass)   SimplifiedGemmEpilogueTraits (cutlass::gemm)   
aligned_chunk (cutlass::platform)   FragmentStore (cutlass)   IgemmEpilogueTraits (cutlass::gemm)   log2_down< N, 1, Count > (cutlass)   SimplifiedGemmTraits (cutlass::gemm)   
aligned_storage (cutlass::platform)   FragmentStore< IteratorFragment::kScalar, kAccessSize, Scalar_, Memory_, FragmentElement_, kStride > (cutlass)   IgemmEpilogueTraitsHelper (cutlass::gemm)   log2_up (cutlass)   SimplifiedGemmTraitsHelper (cutlass::gemm)   
AlignedStruct (cutlass)   FragmentStore< IteratorFragment::kWmmaMatrix, kAccessSize, Scalar_, Memory_, FragmentElement_, kStride > (cutlass)   IgemmFloatToInt8Converter (cutlass::gemm)   log2_up< N, 1, Count > (cutlass)   sqrt_est (cutlass)   
alignment_of (cutlass::platform)   
  g  
-
IgemmGlobalLoadTransformer (cutlass::gemm)   
  m  
-
StorageType (cutlass)   
alignment_of< const value_t > (cutlass::platform)   IgemmGlobalLoadTransformer< Fragment< int8_t, kElements_ >, float > (cutlass::gemm)   StorageType< 1 > (cutlass)   
alignment_of< const volatile value_t > (cutlass::platform)   Gemm (cutlass::gemm)   IgemmGlobalStoreTransformer (cutlass::gemm)   GemmTraits::MainLoopSharedStorage (cutlass::gemm)   StorageType< 2 > (cutlass)   
alignment_of< double2 > (cutlass::platform)   GemmConfig (cutlass::gemm)   IgemmGlobalStoreTransformer< float, Fragment< int8_t, kElements_ > > (cutlass::gemm)   MatrixLayout (cutlass)   StorageType< 4 > (cutlass)   
alignment_of< double4 > (cutlass::platform)   GemmDesc (cutlass::gemm)   IgemmInt8ToFloatConverter (cutlass::gemm)   MemorySpace (cutlass)   Store (cutlass)   
alignment_of< float4 > (cutlass::platform)   GemmEpilogue (cutlass::gemm)   IgemmSharedStoreTransformer (cutlass::gemm)   
  n  
-
Store< double, 2, Memory_, true, 16 > (cutlass)   
alignment_of< int4 > (cutlass::platform)   GemmEpilogueTraits (cutlass::gemm)   IgemmSwizzle (cutlass::gemm)   Store< Scalar_, Lanes_, Memory_, true, 16 > (cutlass)   
alignment_of< long4 > (cutlass::platform)   GemmEpilogueTraitsHelper (cutlass::gemm)   IgemmTileTraitsHelperA (cutlass::gemm)   nullptr_t (cutlass::platform)   Store< Scalar_, Lanes_, Memory_, true, 4 > (cutlass)   
alignment_of< longlong2 > (cutlass::platform)   GemmGlobalIteratorAb (cutlass::gemm)   IgemmTileTraitsHelperA< MatrixLayout::kColumnMajor, GemmConfig_ > (cutlass::gemm)   
  p  
-
Store< Scalar_, Lanes_, Memory_, true, 8 > (cutlass)   
alignment_of< longlong4 > (cutlass::platform)   GemmGlobalIteratorCd (cutlass::gemm)   IgemmTileTraitsHelperB (cutlass::gemm)   GemmTraits::StreamSharedStorage (cutlass::gemm)   
alignment_of< uint4 > (cutlass::platform)   GemmGlobalTileCdTraits (cutlass::gemm)   IgemmTileTraitsHelperB< MatrixLayout::kRowMajor, GemmConfig_ > (cutlass::gemm)   alignment_of::pad (cutlass::platform)   GemmEpilogueTraits::StreamSharedStorage (cutlass::gemm)   
alignment_of< ulong4 > (cutlass::platform)   GemmGlobalTileTraits (cutlass::gemm)   IgemmTraits (cutlass::gemm)   WmmaGemmGlobalIteratorCd::Params (cutlass::gemm)   
  t  
+
GemmConfig (cutlass::gemm)   IgemmTraitsHelper (cutlass::gemm)   LinearScalingDevicePtr::Params (cutlass::gemm)   Store< double, 2, Memory_, FragmentElementType::kScalar, double, kStride, 16 > (cutlass)   
GemmCoord (cutlass::gemm)   IgemmTransformerA (cutlass::gemm)   GlobalLoadStream::Params (cutlass::gemm)   Store< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, 1, 2 > (cutlass)   
aligned_chunk (cutlass::platform)   GemmDesc (cutlass::gemm)   IgemmTransformerA< MatrixLayout::kColumnMajor, Iterator_ > (cutlass::gemm)   SharedStreamPair::Params (cutlass::gemm)   Store< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 16 > (cutlass)   
aligned_storage (cutlass::platform)   GemmEpilogue (cutlass::gemm)   IgemmTransformerA< MatrixLayout::kRowMajor, Iterator_ > (cutlass::gemm)   WmmaGemmGlobalIteratorCd::Params (cutlass::gemm)   Store< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 4 > (cutlass)   
AlignedStruct (cutlass)   GemmEpilogueTraits (cutlass::gemm)   IgemmTransformerB (cutlass::gemm)   ZipTileIterator::Params (cutlass)   Store< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 8 > (cutlass)   
alignment_of (cutlass::platform)   GemmEpilogueTraitsHelper (cutlass::gemm)   IgemmTransformerB< MatrixLayout::kColumnMajor, Iterator_ > (cutlass::gemm)   GemmTraits::Params (cutlass::gemm)   Store< Scalar_, kAccessSize, Memory_, FragmentElementType::kWmmaMatrix, FragmentElement_, kStride, size > (cutlass)   
alignment_of< const value_t > (cutlass::platform)   GemmGlobalIteratorAb (cutlass::gemm)   IgemmTransformerB< MatrixLayout::kRowMajor, Iterator_ > (cutlass::gemm)   LinearScaling::Params (cutlass::gemm)   GemmEpilogueTraits::StreamSharedStorage (cutlass::gemm)   
alignment_of< const volatile value_t > (cutlass::platform)   GemmGlobalIteratorCd (cutlass::gemm)   int4_t (cutlass)   GemmGlobalIteratorAb::Params (cutlass::gemm)   TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::StrideVector (cutlass)   
alignment_of< double2 > (cutlass::platform)   GemmGlobalTileCdTraits (cutlass::gemm)   integral_constant (cutlass::platform)   GlobalLoadStreamPair::Params (cutlass::gemm)   swizzleDirection (cutlass::gemm)   
alignment_of< double4 > (cutlass::platform)   GemmGlobalTileTraits (cutlass::gemm)   is_arithmetic (cutlass::platform)   GemmGlobalIteratorCd::Params (cutlass::gemm)   
  t  
alignment_of< ulonglong2 > (cutlass::platform)   GemmMultiplicandTraits (cutlass::gemm)   IgemmTraitsHelper (cutlass::gemm)   GemmTraits::Params (cutlass::gemm)   
alignment_of< ulonglong4 > (cutlass::platform)   GemmOperand (cutlass)   IgemmTransformerA (cutlass::gemm)   GlobalLoadStreamBase::Params (cutlass::gemm)   TensorRef (cutlass)   
alignment_of< volatile value_t > (cutlass::platform)   GemmOperandTraitsAb (cutlass::gemm)   IgemmTransformerA< MatrixLayout::kColumnMajor, Iterator_ > (cutlass::gemm)   TileIteratorBase::Params (cutlass)   TensorView (cutlass)   
alignment_of< float4 > (cutlass::platform)   GemmMultiplicandTraits (cutlass::gemm)   is_base_of (cutlass::platform)   GemmEpilogueTraits::Params (cutlass::gemm)   
alignment_of< int4 > (cutlass::platform)   GemmOperand (cutlass)   is_base_of_helper (cutlass::platform)   TileIteratorBase::Params (cutlass)   TensorRef (cutlass)   
alignment_of< long4 > (cutlass::platform)   GemmOperandTraitsAb (cutlass::gemm)   is_floating_point (cutlass::platform)   TileLoadIterator::Params (cutlass)   TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ > (cutlass)   
alignment_of< longlong2 > (cutlass::platform)   GemmSharedLoadTileATraits (cutlass::gemm)   is_fundamental (cutlass::platform)   TileStoreIterator::Params (cutlass)   TensorRefArray (cutlass)   
alignment_of< longlong4 > (cutlass::platform)   GemmSharedLoadTileBTraits (cutlass::gemm)   is_integral (cutlass::platform)   TileLoadStream::Params (cutlass)   TensorRefBatchStrided (cutlass)   
alignment_of< uint4 > (cutlass::platform)   GemmSharedLoadTileDTraits (cutlass::gemm)   is_integral< char > (cutlass::platform)   TileStoreStream::Params (cutlass)   TensorView (cutlass)   
alignment_of< ulong4 > (cutlass::platform)   GemmSharedStoreTileAbTraits (cutlass::gemm)   is_integral< const T > (cutlass::platform)   SharedLoadStream::Params (cutlass::gemm)   ThreadMultiplyAdd (cutlass::gemm)   
alignment_of< ulonglong2 > (cutlass::platform)   GemmSharedStoreTileDTraits (cutlass::gemm)   is_integral< const volatile T > (cutlass::platform)   plus (cutlass::platform)   ThreadMultiplyAdd< ThreadGemmShape_, ThreadsPerWarp_, half, half, float > (cutlass::gemm)   
alignment_of< ulonglong4 > (cutlass::platform)   GemmSharedStoreWithSkewTileAbTraits (cutlass::gemm)   is_integral< int > (cutlass::platform)   PredicatedTileLoadStream (cutlass)   ThreadMultiplyAdd< ThreadGemmShape_, ThreadsPerWarp_, half, half, half > (cutlass::gemm)   
alignment_of< volatile value_t > (cutlass::platform)   GemmTileTraitsHelperA (cutlass::gemm)   is_integral< long > (cutlass::platform)   PredicatedTileStoreStream (cutlass)   ThreadMultiplyAdd< ThreadGemmShape_, ThreadsPerWarp_, int8_t, int8_t, int > (cutlass::gemm)   
  b  
-
GemmSharedLoadTileATraits (cutlass::gemm)   IgemmTransformerA< MatrixLayout::kRowMajor, Iterator_ > (cutlass::gemm)   GemmGlobalIteratorCd::Params (cutlass::gemm)   ThreadMultiplyAdd (cutlass::gemm)   
GemmSharedLoadTileBTraits (cutlass::gemm)   IgemmTransformerB (cutlass::gemm)   TileLoadIterator::Params (cutlass)   ThreadMultiplyAdd< AccumulatorsPerThread_, ThreadsPerWarp_, half, half, half > (cutlass::gemm)   
bool_constant (cutlass::platform)   GemmSharedLoadTileDTraits (cutlass::gemm)   IgemmTransformerB< MatrixLayout::kColumnMajor, Iterator_ > (cutlass::gemm)   TileStoreIterator::Params (cutlass)   ThreadMultiplyAdd< AccumulatorsPerThread_, ThreadsPerWarp_, int8_t, int8_t, int > (cutlass::gemm)   
GemmTileTraitsHelperA< MatrixLayout::kColumnMajor, GemmConfig_ > (cutlass::gemm)   is_integral< long long > (cutlass::platform)   PredicateTileAdapter (cutlass)   GemmSharedStoreTileAbTraits::ThreadOffset (cutlass::gemm)   
GemmTileTraitsHelperA< MatrixLayout::kRowMajor, GemmConfig_ > (cutlass::gemm)   is_integral< short > (cutlass::platform)   TileLoadStream::PredicateVector (cutlass)   WmmaGemmGlobalIteratorCdTraits::ThreadOffset (cutlass::gemm)   
bin1_t (cutlass)   GemmTileTraitsHelperB (cutlass::gemm)   is_integral< signed char > (cutlass::platform)   PredicateVector (cutlass)   GemmGlobalTileCdTraits::ThreadOffset (cutlass::gemm)   
bool_constant (cutlass::platform)   GemmTileTraitsHelperB< MatrixLayout::kColumnMajor, GemmConfig_ > (cutlass::gemm)   is_integral< unsigned char > (cutlass::platform)   TileStoreStream::PredicateVector (cutlass)   GemmSharedLoadTileATraits::ThreadOffset (cutlass::gemm)   
  c  
-
GemmSharedStoreTileAbTraits (cutlass::gemm)   IgemmTransformerB< MatrixLayout::kRowMajor, Iterator_ > (cutlass::gemm)   GemmEpilogueTraits::Params (cutlass::gemm)   GemmSharedLoadTileBTraits::ThreadOffset (cutlass::gemm)   
GemmSharedStoreTileDTraits (cutlass::gemm)   integral_constant (cutlass::platform)   Gemm::Params (cutlass::gemm)   GemmGlobalTileCdTraits::ThreadOffset (cutlass::gemm)   
ClearAccumulators (cutlass::gemm)   GemmSharedStoreWithSkewTileAbTraits (cutlass::gemm)   is_arithmetic (cutlass::platform)   SharedLoadStream::Params (cutlass::gemm)   IgemmContiguousGlobalTileTraits::ThreadOffset (cutlass::gemm)   
ComputeOffsetFromShape (cutlass)   GemmTileTraitsHelperA (cutlass::gemm)   is_base_of (cutlass::platform)   LinearScaling::Params (cutlass::gemm)   GemmGlobalTileTraits::ThreadOffset (cutlass::gemm)   
ComputeOffsetFromShape< Shape< 1, kSh_, kSw_, 1 > > (cutlass)   GemmTileTraitsHelperA< MatrixLayout::kColumnMajor, GemmConfig_ > (cutlass::gemm)   is_base_of_helper (cutlass::platform)   GemmGlobalIteratorAb::Params (cutlass::gemm)   GemmSharedLoadTileDTraits::ThreadOffset (cutlass::gemm)   
ComputeOffsetFromShape< Shape< 1, kSh_, kSw_, kSc_ > > (cutlass)   GemmTileTraitsHelperA< MatrixLayout::kRowMajor, GemmConfig_ > (cutlass::gemm)   is_floating_point (cutlass::platform)   plus (cutlass::platform)   GemmSharedLoadTileATraits::ThreadOffset (cutlass::gemm)   
ComputeOffsetFromStrides (cutlass)   GemmTileTraitsHelperB (cutlass::gemm)   is_fundamental (cutlass::platform)   PredicateTileAdapter (cutlass)   GemmSharedStoreTileDTraits::ThreadOffset (cutlass::gemm)   
ComputeOffsetFromStrides< Shape< 1, S_h_, S_w_, 1 > > (cutlass)   GemmTileTraitsHelperB< MatrixLayout::kColumnMajor, GemmConfig_ > (cutlass::gemm)   is_integral (cutlass::platform)   PredicateVector (cutlass)   HgemmCrosswiseGlobalTileTraits::ThreadOffset (cutlass::gemm)   
ComputeOffsetFromStrides< Shape< 1, S_h_, S_w_, S_c_ > > (cutlass)   GemmTileTraitsHelperB< MatrixLayout::kRowMajor, GemmConfig_ > (cutlass::gemm)   is_integral< char > (cutlass::platform)   ProjectOperand (cutlass::gemm)   GemmSharedStoreTileAbTraits::ThreadOffset (cutlass::gemm)   
ComputeThreadOffsetFromStrides (cutlass)   GemmTraits (cutlass::gemm)   is_integral< const T > (cutlass::platform)   ProjectOperand< GemmOperand::kA, Kstrided > (cutlass::gemm)   TileTraitsWarpRake::ThreadOffset (cutlass)   
ComputeThreadOffsetFromStrides< Shape< 1, T_h_, T_w_, 1 >, Shape< 1, S_h_, S_w_, 1 > > (cutlass)   GetExtent (cutlass::gemm)   is_integral< const volatile T > (cutlass::platform)   ProjectOperand< GemmOperand::kB, Kstrided > (cutlass::gemm)   GemmSharedStoreWithSkewTileAbTraits::ThreadOffset (cutlass::gemm)   
ComputeThreadOffsetFromStrides< Shape< 1, T_h_, T_w_, T_c_ >, Shape< 1, S_h_, S_w_, S_c_ > > (cutlass)   GetExtent< GemmOperand::kA, Tile_ > (cutlass::gemm)   is_integral< int > (cutlass::platform)   ProjectOperand< GemmOperand::kC, true > (cutlass::gemm)   WmmaGemmGlobalIteratorCdTraits::ThreadOffset (cutlass::gemm)   
conditional (cutlass::platform)   GetExtent< GemmOperand::kB, Tile_ > (cutlass::gemm)   is_integral< long > (cutlass::platform)   ProjectOperand< GemmOperand::kD, true > (cutlass::gemm)   TiledThreadOffset (cutlass)   
conditional< false, T, F > (cutlass::platform)   GemmTraits::GlobalLoadStream (cutlass::gemm)   is_integral< long long > (cutlass::platform)   
  r  
-
TileIteratorBase (cutlass)   
PredicateVector::ConstIterator (cutlass)   GlobalLoadStream (cutlass::gemm)   is_integral< short > (cutlass::platform)   TileLoadIterator (cutlass)   
ConstPredicateTileAdapter (cutlass)   GlobalLoadStreamBase (cutlass::gemm)   is_integral< signed char > (cutlass::platform)   remove_const (cutlass::platform)   TileStoreIterator (cutlass)   
Convert (cutlass)   greater (cutlass::platform)   is_integral< unsigned char > (cutlass::platform)   remove_const< const T > (cutlass::platform)   TileTraits (cutlass)   
Convert< Fragment< InputScalar_, kScalars_ >, Fragment< OutputScalar_, kScalars_ > > (cutlass)   
  h  
-
is_integral< unsigned int > (cutlass::platform)   remove_cv (cutlass::platform)   TileTraitsContiguousMajor (cutlass)   
Coord (cutlass)   is_integral< unsigned long > (cutlass::platform)   remove_volatile (cutlass::platform)   TileTraitsStandard (cutlass)   
Copy (cutlass)   HgemmConfig (cutlass::gemm)   is_integral< unsigned long long > (cutlass::platform)   remove_volatile< volatile T > (cutlass::platform)   TileTraitsStrideMajor (cutlass)   
GemmTileTraitsHelperB< MatrixLayout::kRowMajor, GemmConfig_ > (cutlass::gemm)   is_integral< unsigned int > (cutlass::platform)   ProjectOperand (cutlass::gemm)   GemmSharedStoreWithSkewTileAbTraits::ThreadOffset (cutlass::gemm)   
GemmTraits (cutlass::gemm)   is_integral< unsigned long > (cutlass::platform)   ProjectOperand< GemmOperand::kA, Kstrided > (cutlass::gemm)   IgemmGlobalTileTraits::ThreadOffset (cutlass::gemm)   
ClearAccumulators (cutlass::gemm)   GetExtent (cutlass::gemm)   is_integral< unsigned long long > (cutlass::platform)   ProjectOperand< GemmOperand::kB, Kstrided > (cutlass::gemm)   GemmSharedLoadTileBTraits::ThreadOffset (cutlass::gemm)   
MatrixLayout::ColumnMajor (cutlass)   GetExtent< GemmOperand::kA, Tile_ > (cutlass::gemm)   is_integral< unsigned short > (cutlass::platform)   ProjectOperand< GemmOperand::kC, true > (cutlass::gemm)   GemmGlobalTileTraits::ThreadOffset (cutlass::gemm)   
MatrixLayout::ColumnMajorBlockLinear (cutlass)   GetExtent< GemmOperand::kB, Tile_ > (cutlass::gemm)   is_integral< volatile T > (cutlass::platform)   ProjectOperand< GemmOperand::kD, true > (cutlass::gemm)   GemmSharedLoadTileDTraits::ThreadOffset (cutlass::gemm)   
ColumnMajorBlockSwizzle (cutlass::gemm)   GlobalLoadStream (cutlass::gemm)   is_pointer (cutlass::platform)   
  r  
+
TileTraitsWarpRake::ThreadOffset (cutlass)   
MatrixLayout::ColumnMajorInterleaved (cutlass)   GlobalLoadStreamPair (cutlass::gemm)   is_pointer_helper (cutlass::platform)   GemmSharedStoreTileDTraits::ThreadOffset (cutlass::gemm)   
complex (cutlass::platform)   greater (cutlass::platform)   is_pointer_helper< T * > (cutlass::platform)   RegularTilePredicateFunctor (cutlass)   HgemmCrosswiseGlobalTileTraits::ThreadOffset (cutlass::gemm)   
ComputeOffsetFromShape (cutlass)   
  h  
+
is_pow2 (cutlass)   remove_const (cutlass::platform)   TileAllocation (cutlass)   
ComputeOffsetFromStrides (cutlass)   is_same (cutlass::platform)   remove_const< const T > (cutlass::platform)   TileCoord (cutlass)   
ComputeThreadOffsetFromStrides (cutlass)   HgemmConfig (cutlass::gemm)   is_same< A, A > (cutlass::platform)   remove_cv (cutlass::platform)   TiledThreadOffset (cutlass)   
ComputeThreadOffsetFromStrides< Shape< 1, T_h_, T_w_, 1 >, Shape< 1, S_h_, S_w_, 1 > > (cutlass)   HgemmCrosswiseGlobalTileTraits (cutlass::gemm)   is_trivially_copyable (cutlass::platform)   remove_volatile (cutlass::platform)   TileIteratorBase (cutlass)   
ComputeThreadOffsetFromStrides< Shape< 1, T_h_, T_w_, T_c_ >, Shape< 1, S_h_, S_w_, S_c_ > > (cutlass)   HgemmSwizzle (cutlass::gemm)   is_void (cutlass::platform)   remove_volatile< volatile T > (cutlass::platform)   TileLoadIterator (cutlass)   
conditional (cutlass::platform)   HgemmTileTraitsHelperA (cutlass::gemm)   is_volatile (cutlass::platform)   ReshapeThreads (cutlass::gemm)   TileLoadStream (cutlass)   
conditional< false, T, F > (cutlass::platform)   HgemmTileTraitsHelperA< MatrixLayout::kRowMajor, GemmConfig_ > (cutlass::gemm)   is_volatile< volatile T > (cutlass::platform)   ReshapeThreads< Tile_, Threads_, true > (cutlass::gemm)   TileStoreIterator (cutlass)   
PredicateVector::ConstIterator (cutlass)   HgemmTileTraitsHelperB (cutlass::gemm)   PredicateVector::Iterator (cutlass)   ReshapeTile (cutlass)   TileStoreStream (cutlass)   
TensorRefBatchStrided::ConstIterator (cutlass)   HgemmTileTraitsHelperB< MatrixLayout::kColumnMajor, GemmConfig_ > (cutlass::gemm)   IteratorAdvance (cutlass)   ReshapeTile< Tile_, kAccessSize_, true > (cutlass)   TileTraits (cutlass)   
TensorRefArray::ConstIterator (cutlass)   HgemmTraits (cutlass::gemm)   
  k  
+
MatrixLayout::RowMajor (cutlass)   TileTraitsContiguousMajor (cutlass)   
ConstPredicateTileAdapter (cutlass)   HgemmTraitsHelper (cutlass::gemm)   MatrixLayout::RowMajorBlockLinear (cutlass)   TileTraitsStandard (cutlass)   
MatrixLayout::ContiguousLayout (cutlass)   HgemmTransformerA (cutlass::gemm)   KernelLaunchConfiguration (cutlass)   RowMajorBlockSwizzle (cutlass::gemm)   TileTraitsStrideMajor (cutlass)   
Convert (cutlass)   HgemmTransformerA< MatrixLayout::kColumnMajor, Iterator_ > (cutlass::gemm)   
  l  
+
MatrixLayout::RowMajorInterleaved (cutlass)   TileTraitsWarpRake (cutlass)   
Convert< Fragment< InputScalar_, kScalars_ >, Fragment< OutputScalar_, kScalars_ > > (cutlass)   HgemmTransformerA< MatrixLayout::kRowMajor, Iterator_ > (cutlass::gemm)   
  s  
+
PredicateVector::TrivialIterator (cutlass)   
Coord (cutlass)   HgemmTransformerB (cutlass::gemm)   Launch (cutlass::gemm)   TrivialPredicateTileAdapter (cutlass)   
Copy (cutlass)   HgemmTransformerB< MatrixLayout::kColumnMajor, Iterator_ > (cutlass::gemm)   Launch< Gemm, false > (cutlass::gemm)   ScalarIO (cutlass)   
  u  
+
  d  
-
HgemmCrosswiseGlobalTileTraits (cutlass::gemm)   is_integral< unsigned short > (cutlass::platform)   ReshapeThreads (cutlass::gemm)   TileTraitsWarpRake (cutlass)   
HgemmSwizzle (cutlass::gemm)   is_integral< volatile T > (cutlass::platform)   ReshapeThreads< Tile_, Threads_, true > (cutlass::gemm)   PredicateVector::TrivialIterator (cutlass)   
default_delete (cutlass::platform)   HgemmTileTraitsHelperA (cutlass::gemm)   is_pointer (cutlass::platform)   ReshapeTile (cutlass)   TrivialPredicateTileAdapter (cutlass)   
default_delete< T[]> (cutlass::platform)   HgemmTileTraitsHelperA< MatrixLayout::kRowMajor, GemmConfig_ > (cutlass::gemm)   is_pointer_helper (cutlass::platform)   ReshapeTile< Tile_, kAccessSize_, true > (cutlass)   
  u  
+
HgemmTransformerB< MatrixLayout::kRowMajor, Iterator_ > (cutlass::gemm)   less (cutlass::platform)   ScalarOrPointer (cutlass::detail)   
  i  
+
LinearScaling (cutlass::gemm)   SgemmConfig (cutlass::gemm)   uint4_t (cutlass)   
DebugType   LinearScalingDevicePtr (cutlass::gemm)   SgemmLBTraits (cutlass::gemm)   unique_ptr (cutlass::platform)   
DebugValue   Identity (cutlass)   Load (cutlass)   SgemmTraits (cutlass::gemm)   
  v  
DgemmConfig (cutlass::gemm)   HgemmTileTraitsHelperB (cutlass::gemm)   is_pointer_helper< T * > (cutlass::platform)   
  s  
-
DgemmTraits (cutlass::gemm)   HgemmTileTraitsHelperB< MatrixLayout::kColumnMajor, GemmConfig_ > (cutlass::gemm)   is_pow2 (cutlass)   unique_ptr (cutlass::platform)   
divide_assert (cutlass)   HgemmTraits (cutlass::gemm)   is_same (cutlass::platform)   SgemmConfig (cutlass::gemm)   
  v  
-
is_base_of_helper::dummy (cutlass::platform)   HgemmTraitsHelper (cutlass::gemm)   is_same< A, A > (cutlass::platform)   SgemmTraits (cutlass::gemm)   
default_delete (cutlass::platform)   IdentityBlockSwizzle (cutlass::gemm)   Load< double, 2, Memory_, FragmentElementType::kScalar, double, kStride, 16 > (cutlass)   Shape (cutlass)   
default_delete< T[]> (cutlass::platform)   IdentityTensorMapFunc (cutlass)   Load< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, 1, 2 > (cutlass)   ShapeAdd (cutlass)   Vector (cutlass)   
DgemmConfig (cutlass::gemm)   IgemmConfig (cutlass::gemm)   Load< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 16 > (cutlass)   ShapeCount (cutlass)   Vector< bin1_t, kLanes_ > (cutlass)   
DgemmTraits (cutlass::gemm)   IgemmConfig< OutputTile_, int8_t, ThreadGemmShape_ > (cutlass::gemm)   Load< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 4 > (cutlass)   ShapeDiv (cutlass)   Vector< half, 1 > (cutlass)   
divide_assert (cutlass)   IgemmEpilogue (cutlass::gemm)   Load< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 8 > (cutlass)   ShapeDivCeiling (cutlass)   Vector< half, kLanes_ > (cutlass)   
is_base_of_helper::dummy (cutlass::platform)   IgemmEpilogue< GemmEpilogueTraits_, true > (cutlass::gemm)   Load< Scalar_, kAccessSize, Memory_, FragmentElementType::kWmmaMatrix, FragmentElement_, kStride, size > (cutlass)   ShapeMax (cutlass)   Vector< int4_t, kLanes_ > (cutlass)   
DumpType (cutlass)   IgemmEpilogueScalar (cutlass::gemm)   Load< Vector< bin1_t, 32 >, kAccessSize, Memory_, FragmentElementType::kWmmaMatrix, FragmentElement_, kStride, size > (cutlass)   ShapeMin (cutlass)   Vector< uint4_t, kLanes_ > (cutlass)   
  e  
-
HgemmTransformerA (cutlass::gemm)   is_trivially_copyable (cutlass::platform)   Shape (cutlass)   Vector (cutlass)   
HgemmTransformerA< MatrixLayout::kColumnMajor, Iterator_ > (cutlass::gemm)   is_void (cutlass::platform)   ShapeAdd (cutlass)   Vector< half, kLanes_ > (cutlass)   
enable_if (cutlass::platform)   HgemmTransformerA< MatrixLayout::kRowMajor, Iterator_ > (cutlass::gemm)   is_volatile (cutlass::platform)   ShapeCount (cutlass)   Vectorize (cutlass)   
enable_if< false, T > (cutlass::platform)   HgemmTransformerB (cutlass::gemm)   is_volatile< volatile T > (cutlass::platform)   ShapeDiv (cutlass)   Vectorize< Element_, 1 > (cutlass)   
Extent (cutlass)   HgemmTransformerB< MatrixLayout::kColumnMajor, Iterator_ > (cutlass::gemm)   PredicateVector::Iterator (cutlass)   ShapeMax (cutlass)   VectorTraits (cutlass)   
Extent< Vector< T, Lanes > > (cutlass)   HgemmTransformerB< MatrixLayout::kRowMajor, Iterator_ > (cutlass::gemm)   IteratorAdvance (cutlass)   ShapeMin (cutlass)   VectorTraits< Vector< T, Lanes > > (cutlass)   
Extent< Vector< T, Lanes > const > (cutlass)   
  i  
-
IteratorFragment (cutlass)   ShapeMul (cutlass)   VectorTraits< Vector< T, Lanes > const > (cutlass)   
IgemmEpilogueScalar< int > (cutlass::gemm)   Load< Vector< int4_t, 8 >, kAccessSize, Memory_, FragmentElementType::kWmmaMatrix, FragmentElement_, kStride, size > (cutlass)   ShapeMul (cutlass)   Vectorize (cutlass)   
IgemmEpilogueTraits (cutlass::gemm)   Load< Vector< uint4_t, 8 >, kAccessSize, Memory_, FragmentElementType::kWmmaMatrix, FragmentElement_, kStride, size > (cutlass)   ShapeScale (cutlass)   Vectorize< Vector< bin1_t, 32 >, kLanes_ > (cutlass)   
enable_if (cutlass::platform)   IgemmEpilogueTraitsHelper (cutlass::gemm)   log2_down (cutlass)   ShapeStrides (cutlass)   Vectorize< Vector< int4_t, 8 >, kLanes_ > (cutlass)   
enable_if< false, T > (cutlass::platform)   IgemmFloatToInt8Converter (cutlass::gemm)   log2_down< N, 1, Count > (cutlass)   ShapeSub (cutlass)   Vectorize< Vector< uint4_t, 8 >, kLanes_ > (cutlass)   
Extent (cutlass)   IgemmGlobalIteratorAb (cutlass::gemm)   log2_up (cutlass)   SharedLoadStream (cutlass::gemm)   VectorTraits (cutlass)   
Extent< Vector< T, Lanes > > (cutlass)   IgemmGlobalLoadTransformer (cutlass::gemm)   log2_up< N, 1, Count > (cutlass)   GemmEpilogueTraits::SharedStorage (cutlass::gemm)   VectorTraits< Vector< T, Lanes > > (cutlass)   
Extent< Vector< T, Lanes > const > (cutlass)   IgemmGlobalLoadTransformer< Fragment< int8_t, kElements_ >, float > (cutlass::gemm)   
  m  
+
GlobalLoadStreamPair::SharedStorage (cutlass::gemm)   VectorTraits< Vector< T, Lanes > const > (cutlass)   
  f  
-
  l  
-
ShapeScale (cutlass)   
  w  
+
IgemmGlobalStoreTransformer (cutlass::gemm)   GemmTraits::SharedStorage (cutlass::gemm)   
  w  
Identity (cutlass)   ShapeStrides (cutlass)   
Fragment (cutlass)   IdentityBlockSwizzle (cutlass::gemm)   less (cutlass::platform)   ShapeSub (cutlass)   WmmaGemmGlobalIteratorCd (cutlass::gemm)   
FragmentConstIterator (cutlass)   IgemmConfig (cutlass::gemm)   LinearScaling (cutlass::gemm)   GemmTraits::SharedLoadStream (cutlass::gemm)   WmmaGemmGlobalIteratorCdTraits (cutlass::gemm)   
FragmentIterator (cutlass)   IgemmConfig< OutputTile_, int8_t, AccumulatorsPerThread_ > (cutlass::gemm)   Load (cutlass)   SharedLoadStream (cutlass::gemm)   
FragmentLoad (cutlass)   IgemmContiguousGlobalTileTraits (cutlass::gemm)   Load< double, 2, Memory_, true, 16 > (cutlass)   ClearAccumulators::SharedStorage (cutlass::gemm)   
FragmentLoad< IteratorFragment::kScalar, kAccessSize, Scalar_, Memory_, FragmentElement_, kStride > (cutlass)   IgemmEpilogue (cutlass::gemm)   Load< Scalar_, Lanes_, Memory_, true, 16 > (cutlass)   GemmEpilogueTraits::SharedStorage (cutlass::gemm)   
FragmentLoad< IteratorFragment::kWmmaMatrix, kAccessSize, Scalar_, Memory_, FragmentElement_, kStride > (cutlass)   IgemmEpilogue< GemmEpilogueTraits_, true > (cutlass::gemm)   Load< Scalar_, Lanes_, Memory_, true, 4 > (cutlass)   GemmTraits::SharedStorage (cutlass::gemm)   
IgemmGlobalStoreTransformer< float, Fragment< int8_t, kElements_ > > (cutlass::gemm)   GemmTraits::MainLoopSharedStorage (cutlass::gemm)   GlobalLoadStream::SharedStorage (cutlass::gemm)   
Fp16SgemmConfig (cutlass::gemm)   IgemmGlobalTileTraits (cutlass::gemm)   MatrixCoord (cutlass)   ClearAccumulators::SharedStorage (cutlass::gemm)   WmmaGemmGlobalIteratorCd (cutlass::gemm)   
Fp16SgemmSgemmTraits (cutlass::gemm)   IgemmInt8ToFloatConverter (cutlass::gemm)   MatrixLayout (cutlass)   SharedStreamPair (cutlass::gemm)   WmmaGemmGlobalIteratorCdTraits (cutlass::gemm)   
Fragment (cutlass)   IgemmSharedStoreTransformer (cutlass::gemm)   MatrixTransform (cutlass)   SimplifiedGemmEpilogueTraits (cutlass::gemm)   
  z  
+
FragmentConstIterator (cutlass)   IgemmSwizzle (cutlass::gemm)   Max (cutlass)   SimplifiedGemmTraits (cutlass::gemm)   
FragmentElementType (cutlass)   IgemmTileTraitsHelperA (cutlass::gemm)   MemorySpace (cutlass)   SimplifiedGemmTraitsHelper (cutlass::gemm)   ZipConvert (cutlass)   
FragmentIterator (cutlass)   IgemmTileTraitsHelperA< MatrixLayout::kColumnMajor, GemmConfig_, Index_ > (cutlass::gemm)   Min (cutlass)   sqrt_est (cutlass)   ZipFragment (cutlass)   
FragmentMultiplyAdd (cutlass::gemm)   IgemmTileTraitsHelperA< MatrixLayout::kRowMajor, GemmConfig_, Index_ > (cutlass::gemm)   
  n  
+
StorageType (cutlass)   ZipTensorRef (cutlass)   
FragmentMultiplyAdd< half, half, true > (cutlass::gemm)   IgemmTileTraitsHelperB (cutlass::gemm)   StorageType< 1 > (cutlass)   ZipTileAllocation (cutlass)   
  g  
+
IgemmTileTraitsHelperB< MatrixLayout::kColumnMajor, GemmConfig_, Index_ > (cutlass::gemm)   nullptr_t (cutlass::platform)   StorageType< 2 > (cutlass)   ZipTileIterator (cutlass)   
IgemmTileTraitsHelperB< MatrixLayout::kRowMajor, GemmConfig_, Index_ > (cutlass::gemm)   
  p  
+
StorageType< 4 > (cutlass)   
Gemm (cutlass::gemm)   IgemmTraits (cutlass::gemm)   Store (cutlass)   
alignment_of::pad (cutlass::platform)   
-
a | b | c | d | e | f | g | h | i | l | m | n | p | r | s | t | u | v | w
+
a | b | c | d | e | f | g | h | i | k | l | m | n | p | r | s | t | u | v | w | z
diff --git a/docs/clear__accumulators_8h.html b/docs/clear__accumulators_8h.html index b4bd3b39c..cd8f6307a 100644 --- a/docs/clear__accumulators_8h.html +++ b/docs/clear__accumulators_8h.html @@ -82,7 +82,7 @@ $(function() {

Defines abstractions for efficiently clearing accumulator tiles. More...

-
#include <cutlass/vector.h>
+
#include "cutlass/vector.h"

Go to the source code of this file.

@@ -104,7 +104,7 @@ Namespaces diff --git a/docs/clear__accumulators_8h_source.html b/docs/clear__accumulators_8h_source.html index 1a6f517fb..7c0423a5f 100644 --- a/docs/clear__accumulators_8h_source.html +++ b/docs/clear__accumulators_8h_source.html @@ -76,16 +76,17 @@ $(function() {
clear_accumulators.h
-Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include <cutlass/vector.h>
31 
32 namespace cutlass {
33 namespace gemm {
34 
36 
37 template <typename Scalar_, int kLanes_ = 1>
40  struct SharedStorage {};
41 
43  CUTLASS_DEVICE ClearAccumulators(SharedStorage& shared_storage) {}
44 
46  template <typename Fragment_>
47  CUTLASS_DEVICE void clear(Fragment_& fragment) {
48  fragment.clear();
49  }
50 };
51 
53 
54 } // namespace gemm
55 } // namespace cutlass
Definition: convert.h:33
+Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include "cutlass/vector.h"
31 
32 namespace cutlass {
33 namespace gemm {
34 
36 
37 template <typename Scalar_, int kLanes_ = 1>
40  struct SharedStorage {};
41 
43  CUTLASS_DEVICE ClearAccumulators(SharedStorage& shared_storage) {}
44 
46  CUTLASS_DEVICE ClearAccumulators() {}
47 
49  template <typename Fragment_>
50  CUTLASS_DEVICE void clear(Fragment_& fragment) {
51  fragment.clear();
52  }
53 };
54 
56 
57 } // namespace gemm
58 } // namespace cutlass
Definition: convert.h:33
Definition: clear_accumulators.h:38
CUTLASS_DEVICE ClearAccumulators(SharedStorage &shared_storage)
Ctor.
Definition: clear_accumulators.h:43
Defines a 1D vector of elements held in the registers of each thread.
-
CUTLASS_DEVICE void clear(Fragment_ &fragment)
Clear the fragment.
Definition: clear_accumulators.h:47
+
CUTLASS_DEVICE void clear(Fragment_ &fragment)
Clear the fragment.
Definition: clear_accumulators.h:50
The shared storage.
Definition: clear_accumulators.h:40
+
CUTLASS_DEVICE ClearAccumulators()
Ctor.
Definition: clear_accumulators.h:46
diff --git a/docs/complex_8h.html b/docs/complex_8h.html new file mode 100644 index 000000000..e94494d21 --- /dev/null +++ b/docs/complex_8h.html @@ -0,0 +1,263 @@ + + + + + + + +Cutlass: complex.h File Reference + + + + + + + + + + +
+
+
+ + + + + +
+
Cutlass +
+
CUDA Templates for Linear Algebra Subroutines and Solvers
+
+
+ + + + + + + + +
+
+ + +
+ +
+ + + +
+ +
+
complex.h File Reference
+
+
+
#include <cuComplex.h>
+#include "cutlass/cutlass.h"
+#include <iosfwd>
+
+

Go to the source code of this file.

+ + + + +

+Classes

class  cutlass::platform::complex< T >
 
+ + + + + +

+Namespaces

 cutlass
 
 cutlass::platform
 
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +

+Functions

CUTLASS_HOST_DEVICE float const & cutlass::platform::real (cuFloatComplex const &z)
 Returns the real part of the complex number. More...
 
CUTLASS_HOST_DEVICE float & cutlass::platform::real (cuFloatComplex &z)
 Returns the real part of the complex number. More...
 
CUTLASS_HOST_DEVICE double const & cutlass::platform::real (cuDoubleComplex const &z)
 Returns the real part of the complex number. More...
 
CUTLASS_HOST_DEVICE double & cutlass::platform::real (cuDoubleComplex &z)
 Returns the real part of the complex number. More...
 
CUTLASS_HOST_DEVICE float const & cutlass::platform::imag (cuFloatComplex const &z)
 Returns the imaginary part of the complex number. More...
 
CUTLASS_HOST_DEVICE float & cutlass::platform::imag (cuFloatComplex &z)
 Returns the imaginary part of the complex number. More...
 
CUTLASS_HOST_DEVICE double const & cutlass::platform::imag (cuDoubleComplex const &z)
 Returns the imaginary part of the complex number. More...
 
CUTLASS_HOST_DEVICE double & cutlass::platform::imag (cuDoubleComplex &z)
 Returns the imaginary part of the complex number. More...
 
template<typename T >
CUTLASS_HOST_DEVICE T const & cutlass::platform::real (complex< T > const &z)
 Returns the real part of the complex number. More...
 
template<typename T >
CUTLASS_HOST_DEVICE T & cutlass::platform::real (complex< T > &z)
 Returns the real part of the complex number. More...
 
template<typename T >
CUTLASS_HOST_DEVICE T const & cutlass::platform::imag (complex< T > const &z)
 Returns the imaginary part of the complex number. More...
 
template<typename T >
CUTLASS_HOST_DEVICE T & cutlass::platform::imag (complex< T > &z)
 Returns the imaginary part of the complex number. More...
 
template<typename T >
std::ostream & cutlass::platform::operator<< (std::ostream &out, complex< T > const &z)
 
template<typename T >
CUTLASS_HOST_DEVICE bool cutlass::platform::operator== (complex< T > const &lhs, complex< T > const &rhs)
 Equality operator. More...
 
template<typename T >
CUTLASS_HOST_DEVICE bool cutlass::platform::operator!= (complex< T > const &lhs, complex< T > const &rhs)
 Inequality operator. More...
 
template<typename T >
CUTLASS_HOST_DEVICE complex< T > cutlass::platform::operator+ (complex< T > const &lhs, complex< T > const &rhs)
 Addition. More...
 
template<typename T >
CUTLASS_HOST_DEVICE complex< T > cutlass::platform::operator- (complex< T > const &lhs, complex< T > const &rhs)
 Subtraction. More...
 
template<typename T >
CUTLASS_HOST_DEVICE complex< T > cutlass::platform::operator* (complex< T > const &lhs, complex< T > const &rhs)
 Multiplication. More...
 
template<typename T >
CUTLASS_HOST_DEVICE complex< T > cutlass::platform::operator* (complex< T > const &lhs, T const &s)
 Scalar Multiplication. More...
 
template<typename T >
CUTLASS_HOST_DEVICE complex< T > cutlass::platform::operator* (T const &s, complex< T > const &rhs)
 Scalar Multiplication. More...
 
template<typename T >
CUTLASS_HOST_DEVICE complex< T > cutlass::platform::operator/ (complex< T > const &lhs, complex< T > const &rhs)
 Division. More...
 
template<typename T >
CUTLASS_HOST_DEVICE complex< T > cutlass::platform::operator/ (complex< T > const &lhs, T const &s)
 Scalar Division. More...
 
template<typename T >
CUTLASS_HOST_DEVICE complex< T > cutlass::platform::operator/ (T const &s, complex< T > const &rhs)
 Scalar divided by complex. More...
 
template<typename T >
CUTLASS_HOST_DEVICE complex< T > & cutlass::platform::operator+= (complex< T > &lhs, complex< T > const &rhs)
 Addition. More...
 
template<typename T >
CUTLASS_HOST_DEVICE complex< T > & cutlass::platform::operator-= (complex< T > &lhs, complex< T > const &rhs)
 Subtraction. More...
 
template<typename T >
CUTLASS_HOST_DEVICE complex< T > & cutlass::platform::operator*= (complex< T > &lhs, complex< T > const &rhs)
 Multiplication. More...
 
template<typename T >
CUTLASS_HOST_DEVICE complex< T > & cutlass::platform::operator*= (complex< T > &lhs, T s)
 Scalar multiplication. More...
 
template<typename T >
CUTLASS_HOST_DEVICE complex< T > & cutlass::platform::operator/= (complex< T > &lhs, complex< T > const &rhs)
 Division. More...
 
template<typename T >
CUTLASS_HOST_DEVICEcutlass::platform::abs (complex< T > const &z)
 Returns the magnitude of the complex number. More...
 
template<typename T >
CUTLASS_HOST_DEVICEcutlass::platform::arg (complex< T > const &z)
 Returns the magnitude of the complex number. More...
 
template<typename T >
CUTLASS_HOST_DEVICEcutlass::platform::norm (complex< T > const &z)
 Returns the squared magnitude. More...
 
template<typename T >
CUTLASS_HOST_DEVICE complex< T > cutlass::platform::conj (complex< T > const &z)
 Returns the complex conjugate. More...
 
template<typename T >
CUTLASS_HOST_DEVICE complex< T > cutlass::platform::proj (complex< T > const &z)
 Projects the complex number z onto the Riemann sphere. More...
 
template<typename T >
CUTLASS_HOST_DEVICE complex< T > cutlass::platform::polar (T const &r, T const &theta=T())
 Returns a complex number with magnitude r and phase theta. More...
 
template<typename T >
CUTLASS_HOST_DEVICE complex< T > cutlass::platform::exp (complex< T > const &z)
 Computes the complex exponential of z. More...
 
template<typename T >
CUTLASS_HOST_DEVICE complex< T > cutlass::platform::log (complex< T > const &z)
 Computes the complex exponential of z. More...
 
template<typename T >
CUTLASS_HOST_DEVICE complex< T > cutlass::platform::log10 (complex< T > const &z)
 Computes the complex exponential of z. More...
 
template<typename T >
CUTLASS_HOST_DEVICE complex< T > cutlass::platform::sqrt (complex< T > const &z)
 Computes the square root of complex number z. More...
 
template<typename T >
CUTLASS_HOST_DEVICE complex< T > cutlass::platform::cos (complex< T > const &z)
 Computes the cosine of complex z. More...
 
template<typename T >
CUTLASS_HOST_DEVICE complex< T > cutlass::platform::sin (complex< T > const &z)
 Computes the sin of complex z. More...
 
+
+ + + + diff --git a/docs/complex_8h_source.html b/docs/complex_8h_source.html new file mode 100644 index 000000000..6270d22da --- /dev/null +++ b/docs/complex_8h_source.html @@ -0,0 +1,123 @@ + + + + + + + +Cutlass: complex.h Source File + + + + + + + + + + +
+
+ + + + + + +
+
Cutlass +
+
CUDA Templates for Linear Algebra Subroutines and Solvers
+
+
+ + + + + + + + +
+
+ + +
+ +
+ + +
+
+
+
complex.h
+
+
+Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
25 #pragma once
26 
27 #include <cuComplex.h>
28 #include "cutlass/cutlass.h"
29 #include <iosfwd>
30 
31 namespace cutlass {
32 namespace platform {
33 
35 
36 //
37 // Accessors for CUDA complex types
38 //
39 
41 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
42  // host-only type
44 float const &real(cuFloatComplex const &z) { return z.x; }
45 
47 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
48  // host-only type
50 float &real(cuFloatComplex &z) { return z.x; }
51 
53 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
54  // host-only type
56 double const &real(cuDoubleComplex const &z) { return z.x; }
57 
59 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
60  // host-only type
62 double &real(cuDoubleComplex &z) { return z.x; }
63 
65 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
66  // host-only type
68 float const &imag(cuFloatComplex const &z) { return z.y; }
69 
71 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
72  // host-only type
74 float &imag(cuFloatComplex &z) { return z.y; }
75 
77 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
78  // host-only type
80 double const &imag(cuDoubleComplex const &z) { return z.y; }
81 
83 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
84  // host-only type
86 double &imag(cuDoubleComplex &z) { return z.y; }
87 
89 
92 template <typename T>
93 class complex {
94  public:
96  typedef T value_type;
97 
98  private:
99  //
100  // Data members
101  //
102 
104  T _real;
105 
107  T _imag;
108 
109  public:
110 //
111 // Methods
112 //
113 
115 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
116  // host-only type
118  complex(T r = T(0), T i = T(0)) : _real(r), _imag(i) {}
119 
121 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
122  // host-only type
124  complex(cuFloatComplex const &z) : _real(platform::real(z)), _imag(platform::imag(z)) {}
125 
127 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
128  // host-only type
130  complex(cuDoubleComplex const &z) : _real(platform::real(z)), _imag(platform::imag(z)) {}
131 
133 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
134  // host-only type
136  T const &real() const { return _real; }
137 
139 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
140  // host-only type
142  T &real() { return _real; }
143 
145 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
146  // host-only type
148  T const &imag() const { return _imag; }
149 
151 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
152  // host-only type
154  T &imag() { return _imag; }
155 
157 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
158  // host-only type
160  operator cuFloatComplex() const { return make_cuFloatComplex(real(), imag()); }
161 
163 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
164  // host-only type
166  operator cuDoubleComplex() const { return make_cuDoubleComplex(real(), imag()); }
167 };
168 
169 //
170 // Accessors for complex template
171 //
172 
174 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
175  // host-only type
176 template <typename T>
177 CUTLASS_HOST_DEVICE T const &real(complex<T> const &z) {
178  return z.real();
179 }
180 
182 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
183  // host-only type
184 template <typename T>
186  return z.real();
187 }
188 
190 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
191  // host-only type
192 template <typename T>
193 CUTLASS_HOST_DEVICE T const &imag(complex<T> const &z) {
194  return z.imag();
195 }
196 
198 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
199  // host-only type
200 template <typename T>
202  return z.imag();
203 }
204 
205 //
206 // Output operators
207 //
208 
209 template <typename T>
210 std::ostream &operator<<(std::ostream &out, complex<T> const &z) {
211  T _r = real(z);
212  T _i = imag(z);
213  return out << _r << "+i" << _i;
214 }
215 
216 //
217 // Non-member operators defined for complex types
218 //
219 
221 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
222  // host-only type
223 template <typename T>
224 CUTLASS_HOST_DEVICE bool operator==(complex<T> const &lhs, complex<T> const &rhs) {
225  return real(lhs) == real(rhs) && imag(lhs) == imag(rhs);
226 }
227 
229 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
230  // host-only type
231 template <typename T>
232 CUTLASS_HOST_DEVICE bool operator!=(complex<T> const &lhs, complex<T> const &rhs) {
233  return !(lhs == rhs);
234 }
235 
237 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
238  // host-only type
239 template <typename T>
241  return complex<T>(real(lhs) + real(rhs), imag(lhs) + imag(rhs));
242 }
243 
245 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
246  // host-only type
247 template <typename T>
249  return complex<T>(real(lhs) - real(rhs), imag(lhs) - imag(rhs));
250 }
251 
253 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
254  // host-only type
255 template <typename T>
257  return complex<T>(real(lhs) * real(rhs) - imag(lhs) * imag(rhs),
258  real(lhs) * imag(rhs) + imag(lhs) * real(rhs));
259 }
260 
262 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
263  // host-only type
264 template <typename T>
266  return complex<T>(real(lhs) * s, imag(lhs) * s);
267 }
268 
270 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
271  // host-only type
272 template <typename T>
274  return complex<T>(s * real(rhs), s * imag(rhs));
275 }
276 
278 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
279  // host-only type
280 template <typename T>
282  T d = (real(rhs) * (rhs) + imag(rhs) * imag(rhs));
283 
284  return complex<T>((real(lhs) * (rhs) + imag(lhs) * imag(rhs)) / d,
285  (imag(lhs) * (rhs)-real(lhs) * imag(rhs)) / d);
286 }
287 
289 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
290  // host-only type
291 template <typename T>
293  return complex<T>(real(lhs) / s, imag(lhs) / s);
294 }
295 
297 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
298  // host-only type
299 template <typename T>
301  T d = (real(rhs) * (rhs) + imag(rhs) * imag(rhs));
302 
303  return complex<T>((s * (rhs)) / d, -(s * imag(rhs)) / d);
304 }
305 
307 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
308  // host-only type
309 template <typename T>
311  lhs = (lhs + rhs);
312  return lhs;
313 }
314 
316 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
317  // host-only type
318 template <typename T>
320  lhs = (lhs - rhs);
321  return lhs;
322 }
323 
325 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
326  // host-only type
327 template <typename T>
329  lhs = (lhs * rhs);
330  return lhs;
331 }
332 
334 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
335  // host-only type
336 template <typename T>
338  lhs = (lhs * s);
339  return lhs;
340 }
341 
343 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
344  // host-only type
345 template <typename T>
347  lhs = (lhs / rhs);
348  return lhs;
349 }
350 
351 //
352 // Non-member functions defined for complex numbers
353 //
354 
356 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
357  // host-only type
358 template <typename T>
360  return sqrt(norm(z));
361 }
362 
364 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
365  // host-only type
366 template <typename T>
368  return atan2(imag(z), real(z));
369 }
370 
372 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
373  // host-only type
374 template <typename T>
376  return real(z) * real(z) + imag(z) * imag(z);
377 }
378 
380 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
381  // host-only type
382 template <typename T>
384  return complex<T>(real(z), -imag(z));
385 }
386 
388 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
389  // host-only type
390 template <typename T>
392  T d = real(z) * real(z) + imag(z) * imag(z) + T(1);
393  return complex<T>((T(2) * real(z)) / d, (T(2) * imag(z)) / d);
394 }
395 
397 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
398  // host-only type
399 template <typename T>
400 CUTLASS_HOST_DEVICE complex<T> polar(T const &r, T const &theta = T()) {
401  return complex<T>(r * cos(theta), r * sin(theta));
402 }
403 
405 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
406  // host-only type
407 template <typename T>
409  return complex<T>(real(z) * cos(imag(z)), real(z) * sin(imag(z)));
410 }
411 
413 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
414  // host-only type
415 template <typename T>
417  return complex<T>(log(abs(z)), arg(z));
418 }
419 
421 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
422  // host-only type
423 template <typename T>
425  return log(z) / T(log(T(10)));
426 }
427 
429 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
430  // host-only type
431 template <typename T>
433  return sqrt(T(2)) / T(2) *
434  complex<T>(sqrt(sqrt(norm(z)) + real(z)),
435  (imag(z) < 0 ? T(-1) : T(1)) * sqrt(sqrt(norm(z)) - real(z)));
436 }
437 
439 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
440  // host-only type
441 template <typename T>
443  return (exp(z) + exp(-z)) / T(2);
444 }
445 
447 #pragma hd_warning_disable // Suppresses warnings when attempting to instantiate complex<T> with a
448  // host-only type
449 template <typename T>
451  return (exp(-z) - exp(z)) * complex<T>(T(0), T(1) / T(2));
452 }
453 
455 
456 } // namespace platform
457 } // namespace cutlass
CUTLASS_HOST_DEVICE complex< T > proj(complex< T > const &z)
Projects the complex number z onto the Riemann sphere.
Definition: complex.h:391
+
Definition: convert.h:33
+
CUTLASS_HOST_DEVICE T & imag()
Accesses the imaginary part of the complex number.
Definition: complex.h:154
+
CUTLASS_HOST_DEVICE bool operator==(complex< T > const &lhs, complex< T > const &rhs)
Equality operator.
Definition: complex.h:224
+
CUTLASS_HOST_DEVICE T const & imag() const
Accesses the imaginary part of the complex number.
Definition: complex.h:148
+
CUTLASS_HOST_DEVICE complex< T > operator*(complex< T > const &lhs, complex< T > const &rhs)
Multiplication.
Definition: complex.h:256
+
CUTLASS_HOST_DEVICE complex< T > & operator-=(complex< T > &lhs, complex< T > const &rhs)
Subtraction.
Definition: complex.h:319
+
CUTLASS_HOST_DEVICE complex< T > operator-(complex< T > const &lhs, complex< T > const &rhs)
Subtraction.
Definition: complex.h:248
+
CUTLASS_HOST_DEVICE T & real()
Accesses the real part of the complex number.
Definition: complex.h:142
+
CUTLASS_HOST_DEVICE float const & real(cuFloatComplex const &z)
Returns the real part of the complex number.
Definition: complex.h:44
+
CUTLASS_HOST_DEVICE complex< T > sin(complex< T > const &z)
Computes the sin of complex z.
Definition: complex.h:450
+
CUTLASS_HOST_DEVICE complex(cuFloatComplex const &z)
Conversion from cuFloatComplex.
Definition: complex.h:124
+
CUTLASS_HOST_DEVICE complex< T > cos(complex< T > const &z)
Computes the cosine of complex z.
Definition: complex.h:442
+
CUTLASS_HOST_DEVICE complex< T > operator+(complex< T > const &lhs, complex< T > const &rhs)
Addition.
Definition: complex.h:240
+
CUTLASS_HOST_DEVICE complex< T > polar(T const &r, T const &theta=T())
Returns a complex number with magnitude r and phase theta.
Definition: complex.h:400
+
CUTLASS_HOST_DEVICE T const & real() const
Accesses the real part of the complex number.
Definition: complex.h:136
+
CUTLASS_HOST_DEVICE complex< T > & operator/=(complex< T > &lhs, complex< T > const &rhs)
Division.
Definition: complex.h:346
+
CUTLASS_HOST_DEVICE complex< T > sqrt(complex< T > const &z)
Computes the square root of complex number z.
Definition: complex.h:432
+
CUTLASS_HOST_DEVICE complex< T > & operator+=(complex< T > &lhs, complex< T > const &rhs)
Addition.
Definition: complex.h:310
+
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
+
CUTLASS_HOST_DEVICE float const & imag(cuFloatComplex const &z)
Returns the imaginary part of the complex number.
Definition: complex.h:68
+
CUTLASS_HOST_DEVICE complex< T > exp(complex< T > const &z)
Computes the complex exponential of z.
Definition: complex.h:408
+
CUTLASS_HOST_DEVICE complex< T > log10(complex< T > const &z)
Computes the complex exponential of z.
Definition: complex.h:424
+
CUTLASS_HOST_DEVICE T norm(complex< T > const &z)
Returns the squared magnitude.
Definition: complex.h:375
+
CUTLASS_HOST_DEVICE bool operator!=(complex< T > const &lhs, complex< T > const &rhs)
Inequality operator.
Definition: complex.h:232
+
CUTLASS_HOST_DEVICE T abs(complex< T > const &z)
Returns the magnitude of the complex number.
Definition: complex.h:359
+
CUTLASS_HOST_DEVICE complex< T > & operator*=(complex< T > &lhs, complex< T > const &rhs)
Multiplication.
Definition: complex.h:328
+
CUTLASS_HOST_DEVICE complex(cuDoubleComplex const &z)
Conversion from cuDoubleComplex.
Definition: complex.h:130
+
CUTLASS_HOST_DEVICE T arg(complex< T > const &z)
Returns the magnitude of the complex number.
Definition: complex.h:367
+
CUTLASS_HOST_DEVICE complex(T r=T(0), T i=T(0))
Constructor.
Definition: complex.h:118
+
Definition: complex.h:93
+
CUTLASS_HOST_DEVICE complex< T > log(complex< T > const &z)
Computes the complex exponential of z.
Definition: complex.h:416
+
T value_type
Type alias for scalar type.
Definition: complex.h:96
+
Basic include for CUTLASS macros.
+
CUTLASS_HOST_DEVICE complex< T > operator/(complex< T > const &lhs, complex< T > const &rhs)
Division.
Definition: complex.h:281
+
CUTLASS_HOST_DEVICE complex< T > conj(complex< T > const &z)
Returns the complex conjugate.
Definition: complex.h:383
+
+ + + + diff --git a/docs/convert_8h.html b/docs/convert_8h.html index 422c52017..cd3bf4bb8 100644 --- a/docs/convert_8h.html +++ b/docs/convert_8h.html @@ -82,7 +82,7 @@ $(function() {

Defines conversion operations among Fragments of different base type. More...

-
#include <cutlass/fragment.h>
+
#include "cutlass/fragment.h"

Go to the source code of this file.

@@ -103,7 +103,7 @@ Namespaces diff --git a/docs/convert_8h_source.html b/docs/convert_8h_source.html index 6e877d293..22ec9d4b8 100644 --- a/docs/convert_8h_source.html +++ b/docs/convert_8h_source.html @@ -76,7 +76,7 @@ $(function() {
convert.h
-Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include <cutlass/fragment.h>
32 
33 namespace cutlass {
34 
36 
37 template <typename InputFragment_, typename OutputFragment_>
38 struct Convert {};
39 
41 
42 template <typename InputScalar_, typename OutputScalar_, int kScalars_>
43 struct Convert<Fragment<InputScalar_, kScalars_>, Fragment<OutputScalar_, kScalars_> > {
48 
50  CUTLASS_DEVICE Convert() {}
51 
53  CUTLASS_DEVICE void transform(InputFragment const& src, OutputFragment& dst) {
54  transform(src, 0, dst);
55  }
56 
58  template <typename Fragment_>
59  CUTLASS_DEVICE void transform(Fragment_ const& src, int offset, OutputFragment& dst) {
60  for (int i = 0; i < kScalars_; ++i) {
61  dst[i] = static_cast<OutputScalar_>(src[i + offset]);
62  }
63  }
64 };
65 
67 
68 template <typename Fragment_>
69 struct Copy {
71  typedef Fragment_ InputFragment;
73  typedef Fragment_ OutputFragment;
74 
76  CUTLASS_DEVICE Copy() {}
77 
79  CUTLASS_DEVICE void transform(Fragment_ const& src, Fragment_& dst) { transform(src, 0, dst); }
80 
82  template <typename InputFragment_>
83  CUTLASS_DEVICE void transform(InputFragment_ const& src, int offset, Fragment_& dst) {
84  if (sizeof(typename Fragment_::Element) == 8) {
85  uint64_t const* src_ptr = reinterpret_cast<uint64_t const*>(&src[offset]);
86  uint64_t* dst_ptr = reinterpret_cast<uint64_t*>(&dst[0]);
87  for (int i = 0; i < sizeof(Fragment_) / 8; ++i) {
88  dst_ptr[i] = src_ptr[i];
89  }
90  } else {
91  uint32_t const* src_ptr = reinterpret_cast<uint32_t const*>(&src[offset]);
92  uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&dst[0]);
93  for (int i = 0; i < sizeof(Fragment_) / 4; ++i) {
94  dst_ptr[i] = src_ptr[i];
95  }
96  }
97  }
98 };
99 
101 
102 } // namespace cutlass
Definition: convert.h:33
+Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include "cutlass/fragment.h"
32 
33 namespace cutlass {
34 
36 
37 template <typename InputFragment_, typename OutputFragment_>
38 struct Convert {};
39 
41 
42 template <typename InputScalar_, typename OutputScalar_, int kScalars_>
43 struct Convert<Fragment<InputScalar_, kScalars_>, Fragment<OutputScalar_, kScalars_> > {
48 
50  CUTLASS_DEVICE Convert() {}
51 
53  CUTLASS_DEVICE void transform(InputFragment const& src, OutputFragment& dst) {
54  transform(src, 0, dst);
55  }
56 
58  template <typename Fragment_>
59  CUTLASS_DEVICE void transform(Fragment_ const& src, int offset, OutputFragment& dst) {
60  for (int i = 0; i < kScalars_; ++i) {
61  dst[i] = static_cast<OutputScalar_>(src[i + offset]);
62  }
63  }
64 };
65 
67 
68 template <typename Fragment_>
69 struct Copy {
71  typedef Fragment_ InputFragment;
73  typedef Fragment_ OutputFragment;
74 
76  CUTLASS_DEVICE Copy() {}
77 
79  CUTLASS_DEVICE void transform(Fragment_ const& src, Fragment_& dst) { transform(src, 0, dst); }
80 
82  template <typename InputFragment_>
83  CUTLASS_DEVICE void transform(InputFragment_ const& src, int offset, Fragment_& dst) {
84  if (sizeof(typename Fragment_::Element) == 8) {
85  uint64_t const* src_ptr = reinterpret_cast<uint64_t const*>(&src[offset]);
86  uint64_t* dst_ptr = reinterpret_cast<uint64_t*>(&dst[0]);
87  for (int i = 0; i < sizeof(Fragment_) / 8; ++i) {
88  dst_ptr[i] = src_ptr[i];
89  }
90  } else {
91  uint32_t const* src_ptr = reinterpret_cast<uint32_t const*>(&src[offset]);
92  uint32_t* dst_ptr = reinterpret_cast<uint32_t*>(&dst[0]);
93  for (int i = 0; i < sizeof(Fragment_) / 4; ++i) {
94  dst_ptr[i] = src_ptr[i];
95  }
96  }
97  }
98 };
99 
101 
102 } // namespace cutlass
Definition: convert.h:33
Fragment< OutputScalar_, kScalars_ > OutputFragment
The output fragment.
Definition: convert.h:47
Definition: convert.h:69
CUTLASS_DEVICE void transform(Fragment_ const &src, Fragment_ &dst)
Transform a fragment.
Definition: convert.h:79
@@ -94,7 +94,7 @@ $(function() {
diff --git a/docs/coord_8h.html b/docs/coord_8h.html index 516503867..8bb9bea4d 100644 --- a/docs/coord_8h.html +++ b/docs/coord_8h.html @@ -83,7 +83,8 @@ $(function() {

A Coord is a coordinate of arbitrary rank into a tensor or matrix. More...

-
@@ -92,7 +93,7 @@ Classes - +
struct  cutlass::Identity
 Describes identity elements. More...
 
struct  cutlass::Coord< N_ >
struct  cutlass::Coord< Rank_, Index_ >
 Statically-sized array specifying Coords within a tensor. More...
 
@@ -115,23 +116,14 @@ Functions - - - - - - - - - - - - + + +
CUTLASS_HOST_DEVICE Coord< 4 > cutlass::make_Coord (int _0, int _1, int _2, int _3)
 Helper to make a 4-element coordinate. More...
 
CUTLASS_HOST_DEVICE Coord< 2 > cutlass::get_Coord_hw (Coord< 3 > const &coord)
 Getter. More...
 
CUTLASS_HOST_DEVICE Coord< 2 > cutlass::get_Coord_hw (Coord< 4 > const &coord)
 Getter. More...
 
CUTLASS_HOST_DEVICE Coord< 3 > cutlass::get_Coord_hwc (Coord< 4 > const &coord)
 Getter. More...
 
CUTLASS_HOST_DEVICE Coord< 3 > cutlass::get_Coord_dhw (Coord< 4 > const &coord)
 Getter. More...
 
template<typename Shape_ >
CUTLASS_HOST_DEVICE Coord< 3 > cutlass::make_Coord_from_shape ()
 
diff --git a/docs/coord_8h_source.html b/docs/coord_8h_source.html index 71ec92e1a..b0e2162cc 100644 --- a/docs/coord_8h_source.html +++ b/docs/coord_8h_source.html @@ -76,50 +76,54 @@ $(function() {
coord.h
-Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include <cutlass/cutlass.h>
32 
33 namespace cutlass {
34 
36 
38 struct Identity {
41  enum Kind { Additive = 0, Multiplicative = 1 };
42 };
43 
45 
47 template <int N_>
48 struct Coord {
49  //
50  // Type and constant definitions
51  //
52 
53  static int const N = N_;
54 
55  //
56  // Data members
57  //
58 
60  int idx[N];
61 
62  //
63  // Methods
64  //
65 
68  Coord(int value = 0) {
69  for (int i = 0; i < N; ++i) {
70  idx[i] = value;
71  }
72  }
73 
76  Coord(int _idx[]) {
77  for (int i = 0; i < N; ++i) {
78  idx[i] = _idx[i];
79  }
80  }
81 
84  Coord operator+(Coord const& b) const {
85  Coord c;
86  for (int i = 0; i < N; ++i) {
87  c.idx[i] = idx[i] + b.idx[i];
88  }
89  return c;
90  }
91 
94  Coord operator-(Coord const& b) const {
95  Coord c;
96  for (int i = 0; i < N; ++i) {
97  c.idx[i] = idx[i] - b.idx[i];
98  }
99  return c;
100  }
101 
104  Coord operator*(Coord const& b) const {
105  Coord c;
106  for (int i = 0; i < N; ++i) {
107  c.idx[i] = idx[i] * b.idx[i];
108  }
109  return c;
110  }
111 
114  Coord operator/(Coord const& b) const {
115  Coord c;
116  for (int i = 0; i < N; ++i) {
117  c.idx[i] = idx[i] / b.idx[i];
118  }
119  return c;
120  }
121 
124  Coord& operator+=(Coord const& b) {
125  for (int i = 0; i < N; ++i) {
126  idx[i] += b.idx[i];
127  }
128  return *this;
129  }
130 
133  Coord& operator-=(Coord const& b) {
134  for (int i = 0; i < N; ++i) {
135  idx[i] -= b.idx[i];
136  }
137  return *this;
138  }
139 
142  Coord& operator*=(Coord const& b) {
143  for (int i = 0; i < N; ++i) {
144  idx[i] *= b.idx[i];
145  }
146  return *this;
147  }
148 
151  Coord& operator/=(Coord const& b) {
152  for (int i = 0; i < N; ++i) {
153  idx[i] /= b.idx[i];
154  }
155  return *this;
156  }
157 
159  CUTLASS_HOST_DEVICE int& operator[](int dim) { return idx[dim]; }
160 
162  CUTLASS_HOST_DEVICE int const& operator[](int dim) const { return idx[dim]; }
163 
165  template <typename T>
166  CUTLASS_HOST_DEVICE T dot(Coord const& b, T sum) const {
167  for (int i = 0; i < N; ++i) {
168  sum += idx[i] * b.idx[i];
169  }
170  return sum;
171  }
172 
174  template <typename T>
175  CUTLASS_HOST_DEVICE T dot(Coord const& b) const {
176  T sum = T(0);
177  for (int i = 0; i < N; ++i) {
178  sum += idx[i] * b.idx[i];
179  }
180  return sum;
181  }
182 
184  template <int Dim>
186  return idx[Dim];
187  }
188 
191  int& at(int dim) { return idx[dim]; }
192 
194  template <int Dim>
195  CUTLASS_HOST_DEVICE int const& at() const {
196  return idx[Dim];
197  }
198 
201  int const& at(int dim) const { return idx[dim]; }
202 
205  bool operator==(Coord<N> const& b) const {
206  bool equal = true;
207  for (int i = 0; equal && i < N; ++i) {
208  equal = (idx[i] == b.idx[i]);
209  }
210  return equal;
211  }
212 
215  bool operator!=(Coord<N> const& b) const { return !(*this == b); }
216 
219  Coord& clamp(Coord<N> const& max, Coord<N> const& min = Coord<N>()) {
220  for (int i = 0; i < N; ++i) {
221  idx[i] = __NV_STD_MAX(__NV_STD_MIN(idx[i], max.idx[i]), min.idx[i]);
222  }
223  return *this;
224  }
225 
228  int count() const {
229  int product = idx[0];
230  for (int i = 1; i < N; ++i) {
231  product *= idx[i];
232  }
233  return product;
234  }
235 };
236 
238 
242  int values[1] = {_0};
243  return Coord<1>(values);
244 }
245 
248 Coord<2> make_Coord(int _0, int _1) {
249  int values[2] = {_0, _1};
250  return Coord<2>(values);
251 }
252 
255 Coord<3> make_Coord(int _0, int _1, int _2) {
256  int values[3] = {_0, _1, _2};
257  return Coord<3>(values);
258 }
259 
262 Coord<4> make_Coord(int _0, int _1, int _2, int _3) {
263  int values[4] = {_0, _1, _2, _3};
264  return Coord<4>(values);
265 }
266 
268 
271 Coord<2> get_Coord_hw(Coord<3> const& coord) { return make_Coord(coord[1], coord[2]); }
272 
275 Coord<2> get_Coord_hw(Coord<4> const& coord) { return make_Coord(coord[1], coord[2]); }
276 
279 Coord<3> get_Coord_hwc(Coord<4> const& coord) { return make_Coord(coord[1], coord[2], coord[3]); }
280 
283 Coord<3> get_Coord_dhw(Coord<4> const& coord) { return make_Coord(coord[0], coord[1], coord[2]); }
284 
286 
287 } // namespace cutlass
CUTLASS_HOST_DEVICE int const & operator[](int dim) const
Member access operator.
Definition: coord.h:162
-
CUTLASS_HOST_DEVICE int count() const
Returns the product of all elements.
Definition: coord.h:228
-
Describes identity elements.
Definition: coord.h:38
-
CUTLASS_HOST_DEVICE constexpr const T & max(const T &a, const T &b)
std::max
Definition: platform.h:207
+Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include "cutlass/cutlass.h"
32 #include "cutlass/util/platform.h"
33 
34 namespace cutlass {
35 
37 
39 struct Identity {
42  enum Kind { Additive = 0, Multiplicative = 1 };
43 };
44 
46 
48 template <int Rank_, typename Index_ = int>
49 struct Coord {
50  //
51  // Type and constant definitions
52  //
53 
55  static int const kRank = Rank_;
56 
58  static int const N = Rank_;
59 
61  typedef Index_ Index;
62 
63  //
64  // Data members
65  //
66 
69 
70  //
71  // Methods
72  //
73 
76  Coord(Index value = 0) {
77  for (int i = 0; i < kRank; ++i) {
78  idx[i] = value;
79  }
80  }
81 
84  Coord(Index _idx[]) {
85  for (int i = 0; i < kRank; ++i) {
86  idx[i] = _idx[i];
87  }
88  }
89 
92  Coord(Coord<kRank> const &coord) {
93  for (int i = 0; i < kRank; ++i) {
94  idx[i] = coord[i];
95  }
96  }
97 
100  template <int Slice>
102  Coord<Slice> slice(int start = 0, Index identity = 0) const {
103  Coord<Slice> result;
104  for (int i = 0; i < Slice; ++i) {
105  if (i + start < kRank) {
106  slice[i] = idx[i + start];
107  }
108  else {
109  slice[i] = identity;
110  }
111  }
112  return result;
113  }
114 
117  operator bool() const {
118  for (int i = 0; i < kRank; ++i) {
119  if (idx[i]) {
120  return true;
121  }
122  }
123  return false;
124  }
125 
128  bool operator!() const {
129  for (int i = 0; i < kRank; ++i) {
130  if (idx[i]) {
131  return false;
132  }
133  }
134  return true;
135  }
136 
139  Coord operator+(Coord const& b) const {
140  Coord c;
141  for (int i = 0; i < kRank; ++i) {
142  c.idx[i] = idx[i] + b.idx[i];
143  }
144  return c;
145  }
146 
149  Coord operator-(Coord const& b) const {
150  Coord c;
151  for (int i = 0; i < kRank; ++i) {
152  c.idx[i] = idx[i] - b.idx[i];
153  }
154  return c;
155  }
156 
159  Coord operator*(Coord const& b) const {
160  Coord c;
161  for (int i = 0; i < kRank; ++i) {
162  c.idx[i] = idx[i] * b.idx[i];
163  }
164  return c;
165  }
166 
169  Coord operator/(Coord const& b) const {
170  Coord c;
171  for (int i = 0; i < kRank; ++i) {
172  c.idx[i] = idx[i] / b.idx[i];
173  }
174  return c;
175  }
176 
179  Coord& operator+=(Coord const& b) {
180  for (int i = 0; i < kRank; ++i) {
181  idx[i] += b.idx[i];
182  }
183  return *this;
184  }
185 
188  Coord& operator-=(Coord const& b) {
189  for (int i = 0; i < kRank; ++i) {
190  idx[i] -= b.idx[i];
191  }
192  return *this;
193  }
194 
197  Coord& operator*=(Coord const& b) {
198  for (int i = 0; i < kRank; ++i) {
199  idx[i] *= b.idx[i];
200  }
201  return *this;
202  }
203 
206  Coord& operator/=(Coord const& b) {
207  for (int i = 0; i < kRank; ++i) {
208  idx[i] /= b.idx[i];
209  }
210  return *this;
211  }
212 
214  CUTLASS_HOST_DEVICE Index& operator[](int dim) { return idx[dim]; }
215 
217  CUTLASS_HOST_DEVICE Index const& operator[](int dim) const { return idx[dim]; }
218 
220  template <typename T>
221  CUTLASS_HOST_DEVICE T dot(Coord const& b, T sum) const {
222  for (int i = 0; i < kRank; ++i) {
223  sum += idx[i] * b.idx[i];
224  }
225  return sum;
226  }
227 
229  template <typename T>
230  CUTLASS_HOST_DEVICE T dot(Coord const& b) const {
231  T sum = T(0);
232  for (int i = 0; i < kRank; ++i) {
233  sum += idx[i] * b.idx[i];
234  }
235  return sum;
236  }
237 
239  template <int Dim>
241  return idx[Dim];
242  }
243 
246  Index& at(int dim) { return idx[dim]; }
247 
249  template <int Dim>
250  CUTLASS_HOST_DEVICE Index const& at() const {
251  return idx[Dim];
252  }
253 
256  Index const& at(int dim) const { return idx[dim]; }
257 
260  bool operator==(Coord<kRank> const& b) const {
261  bool equal = true;
262  for (int i = 0; equal && i < kRank; ++i) {
263  equal = (idx[i] == b.idx[i]);
264  }
265  return equal;
266  }
267 
270  bool operator!=(Coord<kRank> const& b) const { return !(*this == b); }
271 
275  for (int i = 0; i < kRank; ++i) {
276  idx[i] = __NV_STD_MAX(__NV_STD_MIN(idx[i], max.idx[i]), min.idx[i]);
277  }
278  return *this;
279  }
280 
283  Index count() const {
284  Index product = idx[0];
285  for (int i = 1; i < kRank; ++i) {
286  product *= idx[i];
287  }
288  return product;
289  }
290 
293  bool operator<(Coord<kRank> const &b) const {
294  for (int i = 0; i < kRank; ++i) {
295  if (!(idx[i] < b[i])) {
296  return false;
297  }
298  }
299  return true;
300  }
301 
304  bool operator<=(Coord<kRank> const &b) const {
305  for (int i = 0; i < kRank; ++i) {
306  if (!(idx[i] <= b[i])) {
307  return false;
308  }
309  }
310  return true;
311  }
312 };
313 
315 
319  int values[1] = {_0};
320  return Coord<1>(values);
321 }
322 
325 Coord<2> make_Coord(int _0, int _1) {
326  int values[2] = {_0, _1};
327  return Coord<2>(values);
328 }
329 
332 Coord<3> make_Coord(int _0, int _1, int _2) {
333  int values[3] = {_0, _1, _2};
334  return Coord<3>(values);
335 }
336 
339 Coord<4> make_Coord(int _0, int _1, int _2, int _3) {
340  int values[4] = {_0, _1, _2, _3};
341  return Coord<4>(values);
342 }
343 
345 
346 template <typename Shape_>
348  return make_Coord(Shape_::kD, Shape_::kH, Shape_::kW);
349 }
350 
352 
353 } // namespace cutlass
Describes identity elements.
Definition: coord.h:39
+
CUTLASS_HOST_DEVICE constexpr const T & max(const T &a, const T &b)
std::max
Definition: platform.h:215
Definition: convert.h:33
-
CUTLASS_HOST_DEVICE bool operator==(Coord< N > const &b) const
Determines if two Coord<> objects are equal.
Definition: coord.h:205
-
CUTLASS_HOST_DEVICE Coord & operator+=(Coord const &b)
In-place addition.
Definition: coord.h:124
-
CUTLASS_HOST_DEVICE bool operator!=(Coord< N > const &b) const
Not equal.
Definition: coord.h:215
-
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 2-element coordinate.
Definition: coord.h:241
-
CUTLASS_HOST_DEVICE Coord< 3 > get_Coord_hwc(Coord< 4 > const &coord)
Getter.
Definition: coord.h:279
-
CUTLASS_HOST_DEVICE Coord< 3 > get_Coord_dhw(Coord< 4 > const &coord)
Getter.
Definition: coord.h:283
-
CUTLASS_HOST_DEVICE Coord & clamp(Coord< N > const &max, Coord< N > const &min=Coord< N >())
Clamps a coordinate to a range specified by maximum and minimum values.
Definition: coord.h:219
-
CUTLASS_HOST_DEVICE int const & at() const
Gets the index of a given Coord element.
Definition: coord.h:195
-
CUTLASS_HOST_DEVICE Coord operator/(Coord const &b) const
Element-wise division.
Definition: coord.h:114
-
Kind
Definition: coord.h:41
-
CUTLASS_HOST_DEVICE T dot(Coord const &b, T sum) const
Computes the dot product of two Coord instances.
Definition: coord.h:166
-
CUTLASS_HOST_DEVICE Coord(int _idx[])
Constructs from an array of integers.
Definition: coord.h:76
-
#define __NV_STD_MAX(a, b)
Select maximum(a, b)
Definition: platform.h:155
-
CUTLASS_HOST_DEVICE int & at(int dim)
Access via index; may limit unrolling potential.
Definition: coord.h:191
-
CUTLASS_HOST_DEVICE int & operator[](int dim)
Member access operator.
Definition: coord.h:159
-
CUTLASS_HOST_DEVICE Coord & operator-=(Coord const &b)
In-place subtraction.
Definition: coord.h:133
-
CUTLASS_HOST_DEVICE Coord operator*(Coord const &b) const
Element-wise multiplication.
Definition: coord.h:104
-
CUTLASS_HOST_DEVICE Coord(int value=0)
Default ctor initializes uniformly.
Definition: coord.h:68
-
CUTLASS_HOST_DEVICE Coord< 2 > get_Coord_hw(Coord< 3 > const &coord)
Getter.
Definition: coord.h:271
-
static int const N
Definition: coord.h:53
-
#define __NV_STD_MIN(a, b)
Select minimum(a, b)
Definition: platform.h:160
-
CUTLASS_HOST_DEVICE T dot(Coord const &b) const
Computes the dot product of two Coord instances.
Definition: coord.h:175
-
CUTLASS_HOST_DEVICE Coord operator-(Coord const &b) const
Element-wise subtraction.
Definition: coord.h:94
+
CUTLASS_HOST_DEVICE Coord operator-(Coord const &b) const
Element-wise subtraction.
Definition: coord.h:149
+
CUTLASS_HOST_DEVICE Index const & at(int dim) const
Access via index; may limit unrolling potential.
Definition: coord.h:256
+
CUTLASS_HOST_DEVICE Index const & operator[](int dim) const
Member access operator.
Definition: coord.h:217
+
CUTLASS_HOST_DEVICE Coord operator/(Coord const &b) const
Element-wise division.
Definition: coord.h:169
+
CUTLASS_HOST_DEVICE Index & operator[](int dim)
Member access operator.
Definition: coord.h:214
+
CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
Helper to make a 2-element coordinate.
Definition: coord.h:318
+
static int const kRank
Number of elements in Coord.
Definition: coord.h:55
+
Index_ Index
Index type used to store elements.
Definition: coord.h:61
+
CUTLASS_HOST_DEVICE Coord & operator*=(Coord const &b)
In-place multiplication.
Definition: coord.h:197
+
CUTLASS_HOST_DEVICE Index & at(int dim)
Access via index; may limit unrolling potential.
Definition: coord.h:246
+
C++ features that may be otherwise unimplemented for CUDA device functions.
+
CUTLASS_HOST_DEVICE Index count() const
Returns the product of all elements.
Definition: coord.h:283
+
CUTLASS_HOST_DEVICE Coord operator*(Coord const &b) const
Element-wise multiplication.
Definition: coord.h:159
+
Kind
Definition: coord.h:42
+
CUTLASS_HOST_DEVICE Coord< 3 > make_Coord_from_shape()
Definition: coord.h:347
+
CUTLASS_HOST_DEVICE bool operator==(Coord< kRank > const &b) const
Determines if two Coord<> objects are equal.
Definition: coord.h:260
+
static int const N
Number of elements in Coord, aliased for compatibility.
Definition: coord.h:58
+
#define __NV_STD_MAX(a, b)
Select maximum(a, b)
Definition: platform.h:163
+
Index idx[kRank]
Indices.
Definition: coord.h:68
+
#define __NV_STD_MIN(a, b)
Select minimum(a, b)
Definition: platform.h:168
+
CUTLASS_HOST_DEVICE Coord & operator-=(Coord const &b)
In-place subtraction.
Definition: coord.h:188
+
CUTLASS_HOST_DEVICE Coord & operator+=(Coord const &b)
In-place addition.
Definition: coord.h:179
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
-
CUTLASS_HOST_DEVICE constexpr const T & min(const T &a, const T &b)
std::min
Definition: platform.h:201
-
Definition: coord.h:41
-
Statically-sized array specifying Coords within a tensor.
Definition: coord.h:48
-
CUTLASS_HOST_DEVICE int & at()
Gets the index of a given Coord element.
Definition: coord.h:185
-
int idx[N]
Indices.
Definition: coord.h:60
-
Definition: coord.h:41
-
CUTLASS_HOST_DEVICE int const & at(int dim) const
Access via index; may limit unrolling potential.
Definition: coord.h:201
+
CUTLASS_HOST_DEVICE bool operator!=(Coord< kRank > const &b) const
Not equal.
Definition: coord.h:270
+
CUTLASS_HOST_DEVICE constexpr const T & min(const T &a, const T &b)
std::min
Definition: platform.h:209
+
CUTLASS_HOST_DEVICE Index & at()
Gets the index of a given Coord element.
Definition: coord.h:240
+
CUTLASS_HOST_DEVICE Coord & operator/=(Coord const &b)
In-place division.
Definition: coord.h:206
+
Definition: coord.h:42
+
CUTLASS_HOST_DEVICE Coord< Slice > slice(int start=0, Index identity=0) const
Definition: coord.h:102
+
Statically-sized array specifying Coords within a tensor.
Definition: coord.h:49
+
CUTLASS_HOST_DEVICE Index const & at() const
Gets the index of a given Coord element.
Definition: coord.h:250
+
CUTLASS_HOST_DEVICE T dot(Coord const &b, T sum) const
Computes the dot product of two Coord instances.
Definition: coord.h:221
+
CUTLASS_HOST_DEVICE Coord(Index value=0)
Default ctor initializes uniformly.
Definition: coord.h:76
+
Definition: coord.h:42
+
CUTLASS_HOST_DEVICE Coord & clamp(Coord< kRank > const &max, Coord< kRank > const &min=Coord< kRank >())
Clamps a coordinate to a range specified by maximum and minimum values.
Definition: coord.h:274
+
CUTLASS_HOST_DEVICE Coord(Index _idx[])
Constructs from an array of integers.
Definition: coord.h:84
+
CUTLASS_HOST_DEVICE T dot(Coord const &b) const
Computes the dot product of two Coord instances.
Definition: coord.h:230
+
CUTLASS_HOST_DEVICE Coord operator+(Coord const &b) const
Element-wise addition.
Definition: coord.h:139
Basic include for CUTLASS macros.
-
CUTLASS_HOST_DEVICE Coord & operator*=(Coord const &b)
In-place multiplication.
Definition: coord.h:142
-
CUTLASS_HOST_DEVICE Coord operator+(Coord const &b) const
Element-wise addition.
Definition: coord.h:84
-
CUTLASS_HOST_DEVICE Coord & operator/=(Coord const &b)
In-place division.
Definition: coord.h:151
+
CUTLASS_HOST_DEVICE Coord(Coord< kRank > const &coord)
Constructs from an array of integers.
Definition: coord.h:92
+
CUTLASS_HOST_DEVICE bool operator!() const
Returns true if Coord is uniformly zero.
Definition: coord.h:128
diff --git a/docs/core__io_8h.html b/docs/core__io_8h.html index d71c39716..2f50d7851 100644 --- a/docs/core__io_8h.html +++ b/docs/core__io_8h.html @@ -73,6 +73,8 @@ $(function() {
core_io.h File Reference
@@ -83,51 +85,56 @@ $(function() { More...

#include <iosfwd>
#include <typeinfo>
-#include <cutlass/coord.h>
+#include "cutlass/coord.h"
+#include "cutlass/vector.h"

Go to the source code of this file.

+ + + + +

+Classes

struct  cutlass::ScalarIO< T >
 Helper to enable formatted printing of CUTLASS scalar types to an ostream. More...
 
+ + + +

+Namespaces

 cutlass
 
- - - + + + + + + + + + + + + + + + + + + + + + + + + + + +

Functions

template<int Rank>
std::ostream & operator<< (std::ostream &out, cutlass::Coord< Rank > const &coord)
 
template<int Rank>
std::ostream & cutlass::operator<< (std::ostream &out, Coord< Rank > const &coord)
 
template<typename T >
std::ostream & cutlass::operator<< (std::ostream &out, ScalarIO< T > const &scalar)
 Default printing to ostream. More...
 
template<>
std::ostream & cutlass::operator<< (std::ostream &out, ScalarIO< int8_t > const &scalar)
 Printing to ostream of int8_t as integer rather than character. More...
 
template<>
std::ostream & cutlass::operator<< (std::ostream &out, ScalarIO< uint8_t > const &scalar)
 Printing to ostream of uint8_t as integer rather than character. More...
 
template<>
std::ostream & cutlass::operator<< (std::ostream &out, ScalarIO< cutlass::Vector< cutlass::bin1_t, 32 > > const &scalar)
 Printing to ostream of vector of 1b elements. More...
 
template<>
std::ostream & cutlass::operator<< (std::ostream &out, ScalarIO< cutlass::Vector< cutlass::int4_t, 8 > > const &scalar)
 Printing to ostream of vector of 4b signed integer elements. More...
 
template<>
std::ostream & cutlass::operator<< (std::ostream &out, ScalarIO< cutlass::Vector< cutlass::uint4_t, 8 > > const &scalar)
 Printing to ostream of vector of 4b unsigned integer elements. More...
 
-

Function Documentation

- -

◆ operator<<()

- -
-
-
-template<int Rank>
- - - - - - - - - - - - - - - - - - -
std::ostream& operator<< (std::ostream & out,
cutlass::Coord< Rank > const & coord 
)
-
- -
-
diff --git a/docs/core__io_8h_source.html b/docs/core__io_8h_source.html index 7c076c94d..21b790113 100644 --- a/docs/core__io_8h_source.html +++ b/docs/core__io_8h_source.html @@ -76,11 +76,19 @@ $(function() {
core_io.h
-Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
25 #pragma once
26 
31 #pragma once
32 
33 #include <iosfwd>
34 #include <typeinfo>
35 
36 #include <cutlass/coord.h>
37 
38 template <int Rank>
39 std::ostream& operator<<(std::ostream& out, cutlass::Coord<Rank> const& coord) {
40  for (int i = 0; i < Rank; ++i) {
41  out << (i ? ", " : "") << coord.idx[i];
42  }
43  return out;
44 }
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
+Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include <iosfwd>
32 #include <typeinfo>
33 
34 #include "cutlass/coord.h"
35 #include "cutlass/vector.h"
36 
37 namespace cutlass {
38 
40 
41 template <int Rank>
42 std::ostream& operator<<(std::ostream& out, Coord<Rank> const& coord) {
43  for (int i = 0; i < Rank; ++i) {
44  out << (i ? ", " : "") << coord.idx[i];
45  }
46  return out;
47 }
48 
50 
52 template <typename T>
53 struct ScalarIO {
54 
56  T value;
57 
59  ScalarIO() { }
60 
63 };
64 
66 
68 template <typename T>
69 inline std::ostream &operator<<(std::ostream &out, ScalarIO<T> const &scalar) {
70  return out << scalar.value;
71 }
72 
74 template <>
75 inline std::ostream &operator<<(std::ostream &out, ScalarIO<int8_t> const &scalar) {
76  return out << int(scalar.value);
77 }
78 
80 template <>
81 inline std::ostream &operator<<(std::ostream &out, ScalarIO<uint8_t> const &scalar) {
82  return out << unsigned(scalar.value);
83 }
84 
86 template <>
87 inline std::ostream &operator<<(
88  std::ostream &out,
90 
91  for (int i = 0; i < 32; i++) {
92  out << int(scalar.value[i]);
93  out << ((i != 31) ? ", " : "");
94  }
95  return out;
96 }
97 
99 template <>
100 inline std::ostream &operator<<(
101  std::ostream &out,
103 
104  for (int i = 0; i < 8; i++) {
105  out << int(scalar.value[i]);
106  out << ((i != 7) ? ", " : "");
107  }
108  return out;
109 }
110 
112 template <>
113 inline std::ostream &operator<<(
114  std::ostream &out,
116 
117  for (int i = 0; i < 8; i++) {
118  out << unsigned(scalar.value[i]);
119  out << ((i != 7) ? ", " : "");
120  }
121  return out;
122 }
123 
125 
126 } // namespace cutlass
Definition: convert.h:33
+
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
+
ScalarIO(T value)
Constructs from a value.
Definition: core_io.h:62
+
ScalarIO()
Default ctor.
Definition: core_io.h:59
+
std::ostream & operator<<(std::ostream &out, Coord< Rank > const &coord)
Definition: core_io.h:42
+
Helper to enable formatted printing of CUTLASS scalar types to an ostream.
Definition: core_io.h:53
+
Definition: vector.h:62
+
T value
Value to print.
Definition: core_io.h:56
+
Defines a 1D vector of elements held in the registers of each thread.
diff --git a/docs/cutlass_8h.html b/docs/cutlass_8h.html index bbb0463c9..419c9123f 100644 --- a/docs/cutlass_8h.html +++ b/docs/cutlass_8h.html @@ -73,8 +73,10 @@ $(function() {
cutlass.h File Reference
@@ -85,6 +87,13 @@ $(function() {

Go to the source code of this file.

+ + + + + +

+Classes

struct  DebugType< T >
 
struct  DebugValue< Value >
 
@@ -96,18 +105,26 @@ Macros - + + + - - + + +

Namespaces

 cutlass
 
#define CUTLASS_MINOR   0
 
#define CUTLASS_PATCH   0
#define CUTLASS_PATCH   1
 
#define CUTLASS_VERSION   ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)
 
#define CUTLASS_HOST_DEVICE
 
#define CUTLASS_ASSERT(x)   assert(x)
 
#define CUTLASS_PRAGMA_UNROLL
 
#define CUTLASS_PRAGMA_NO_UNROLL
 
#define CUTLASS_ASSERT(x)   assert(x)
 
#define CUTLASS_GEMM_LOOP   CUTLASS_PRAGMA_NO_UNROLL
 
+ + + +

+Functions

template<typename T >
void DebugTypeFunc (T const &t)
 

Macro Definition Documentation

@@ -126,6 +143,20 @@ Macros
+
+
+ +

◆ CUTLASS_GEMM_LOOP

+ +
+
+ + + + +
#define CUTLASS_GEMM_LOOP   CUTLASS_PRAGMA_NO_UNROLL
+
+
@@ -177,7 +208,7 @@ Macros
- +
#define CUTLASS_PATCH   0#define CUTLASS_PATCH   1
@@ -224,12 +255,33 @@ Macros
+
+
+

Function Documentation

+ +

◆ DebugTypeFunc()

+ +
+
+
+template<typename T >
+ + + + + + + + +
void DebugTypeFunc (T const & t)
+
+
diff --git a/docs/cutlass_8h_source.html b/docs/cutlass_8h_source.html index d2f442295..9c9fb2b29 100644 --- a/docs/cutlass_8h_source.html +++ b/docs/cutlass_8h_source.html @@ -76,11 +76,14 @@ $(function() {
cutlass.h
-Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
25 
30 #pragma once
31 
33 
34 #define CUTLASS_MAJOR 1
35 #define CUTLASS_MINOR 0
36 #define CUTLASS_PATCH 0
37 #define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)
38 
39 #ifdef __NVCC__
40 #define CUTLASS_HOST_DEVICE __forceinline__ __device__ __host__
41 #define CUTLASS_DEVICE __forceinline__ __device__
42 #elif defined(__CUDACC_RTC__)
43 #define CUTLASS_HOST_DEVICE __forceinline__ __device__
44 #define CUTLASS_DEVICE __forceinline__ __device__
45 #else
46 #define CUTLASS_HOST_DEVICE
47 // CUTLASS_DEVICE is an error if not compiling device code
48 #endif
49 
50 // CUTLASS_PRAGMA_UNROLL inserts a CUTLASS_PRAGMA_UNROLL if supported by the compiler
51 #if defined(__CUDA_ARCH__)
52 #if defined(_MSC_VER)
53 #define CUTLASS_PRAGMA_UNROLL __pragma("unroll")
54 #define CUTLASS_PRAGMA_NO_UNROLL __pragma("unroll 1")
55 #else
56 #define CUTLASS_PRAGMA_UNROLL _Pragma("unroll")
57 #define CUTLASS_PRAGMA_NO_UNROLL _Pragma("unroll 1")
58 #endif
59 #else
60 #define CUTLASS_PRAGMA_UNROLL
61 #define CUTLASS_PRAGMA_NO_UNROLL
62 #endif
63 
64 #define CUTLASS_ASSERT(x) assert(x)
65 
66 namespace cutlass {
67 
69 static const int kWarpSize = 32;
70 
71 } // namespace cutlass
72 
Definition: convert.h:33
+Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
25 
30 #pragma once
31 
33 
34 #define CUTLASS_MAJOR 1
35 #define CUTLASS_MINOR 0
36 #define CUTLASS_PATCH 1
37 #define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH)
38 
39 #ifdef __NVCC__
40 #define CUTLASS_HOST_DEVICE __forceinline__ __device__ __host__
41 #define CUTLASS_DEVICE __forceinline__ __device__
42 #elif defined(__CUDACC_RTC__)
43 #define CUTLASS_HOST_DEVICE __forceinline__ __device__
44 #define CUTLASS_DEVICE __forceinline__ __device__
45 #else
46 #define CUTLASS_HOST_DEVICE
47 // CUTLASS_DEVICE is an error if not compiling device code
48 #endif
49 
50 #define CUTLASS_ASSERT(x) assert(x)
51 
52 // CUTLASS_PRAGMA_(UNROLL|NO_UNROLL) optimization directives for the CUDA compiler.
53 #if defined(__CUDA_ARCH__)
54 #if defined(_MSC_VER)
55 #define CUTLASS_PRAGMA_UNROLL __pragma("unroll")
56 #define CUTLASS_PRAGMA_NO_UNROLL __pragma("unroll 1")
57 #else
58 #define CUTLASS_PRAGMA_UNROLL _Pragma("unroll")
59 #define CUTLASS_PRAGMA_NO_UNROLL _Pragma("unroll 1")
60 #endif
61 #else
62 #define CUTLASS_PRAGMA_UNROLL
63 #define CUTLASS_PRAGMA_NO_UNROLL
64 #endif
65 
66 #define CUTLASS_GEMM_LOOP CUTLASS_PRAGMA_NO_UNROLL
67 
68 // A small helper class to dump a type at compile time
69 // Usage:: DumpType<Class>::Class
70 template <typename T>
71 struct DebugType {};
72 
73 template <typename T>
74 void DebugTypeFunc(T const& t) {
75  T::t;
76 }
77 
78 // A small helper class to dump a compile time constant at compile time
79 // Usage: DumpValue<Class::kConstant>::kConstant
80 template <int Value>
81 struct DebugValue {};
82 
83 namespace cutlass {
84 
86 static const int kWarpSize = 32;
87 
88 } // namespace cutlass
89 
Definition: convert.h:33
+
Definition: cutlass.h:81
+
Definition: cutlass.h:71
+
void DebugTypeFunc(T const &t)
Definition: cutlass.h:74
diff --git a/docs/cutlass__math_8h.html b/docs/cutlass__math_8h.html index 953b0d4c7..c4dbc54b0 100644 --- a/docs/cutlass__math_8h.html +++ b/docs/cutlass__math_8h.html @@ -83,7 +83,7 @@ $(function() {

Math utilities. More...

-
#include <cutlass/util/platform.h>
+

Go to the source code of this file.

@@ -103,6 +103,10 @@ Classes + + + +
 
struct  cutlass::divide_assert< Dividend, Divisor >
 
struct  cutlass::Min< A, B >
 
struct  cutlass::Max< A, B >
 
@@ -120,11 +124,17 @@ Functions + + + + + +

Namespaces

template<typename value_t >
CUTLASS_HOST_DEVICE value_t cutlass::lcm (value_t a, value_t b)
 
template<typename value_t >
CUTLASS_HOST_DEVICE value_t cutlass::clz (value_t x)
 
template<typename value_t >
CUTLASS_HOST_DEVICE value_t cutlass::find_log2 (value_t x)
 
diff --git a/docs/cutlass__math_8h_source.html b/docs/cutlass__math_8h_source.html index 2809a8456..8381f641a 100644 --- a/docs/cutlass__math_8h_source.html +++ b/docs/cutlass__math_8h_source.html @@ -76,27 +76,33 @@ $(function() {
cutlass_math.h
-Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
25 
26 #pragma once
27 
33 #include <cutlass/util/platform.h>
34 
35 namespace cutlass {
36 
37 /******************************************************************************
38  * Static math utilities
39  ******************************************************************************/
40 
44 template <int N>
45 struct is_pow2 : platform::integral_constant<bool, (N & (N - 1)) == 0> {};
46 
50 template <int N, int CurrentVal = N, int Count = 0>
51 struct log2_down {
53  enum { value = log2_down<N, (CurrentVal >> 1), Count + 1>::value };
54 };
55 
56 // Base case
57 template <int N, int Count>
58 struct log2_down<N, 1, Count> {
59  enum { value = Count };
60 };
61 
65 template <int N, int CurrentVal = N, int Count = 0>
66 struct log2_up {
68  enum { value = log2_up<N, (CurrentVal >> 1), Count + 1>::value };
69 };
70 
71 // Base case
72 template <int N, int Count>
73 struct log2_up<N, 1, Count> {
74  enum { value = ((1 << Count) < N) ? Count + 1 : Count };
75 };
76 
80 template <int N>
81 struct sqrt_est {
82  enum { value = 1 << (log2_up<N>::value / 2) };
83 };
84 
89 template <int Dividend, int Divisor>
90 struct divide_assert {
91  enum { value = Dividend / Divisor };
92 
93  static_assert((Dividend % Divisor == 0), "Not an even multiple");
94 };
95 
96 /******************************************************************************
97  * Rounding
98  ******************************************************************************/
99 
103 template <typename dividend_t, typename divisor_t>
104 CUTLASS_HOST_DEVICE dividend_t round_nearest(dividend_t dividend, divisor_t divisor) {
105  return ((dividend + divisor - 1) / divisor) * divisor;
106 }
107 
111 template <typename value_t>
112 CUTLASS_HOST_DEVICE value_t gcd(value_t a, value_t b) {
113  for (;;) {
114  if (a == 0) return b;
115  b %= a;
116  if (b == 0) return a;
117  a %= b;
118  }
119 }
120 
124 template <typename value_t>
125 CUTLASS_HOST_DEVICE value_t lcm(value_t a, value_t b) {
126  value_t temp = gcd(a, b);
127 
128  return temp ? (a / temp * b) : 0;
129 }
130 
131 } // namespace cutlass
Definition: cutlass_math.h:91
+Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
25 
26 #pragma once
27 
33 #include "cutlass/util/platform.h"
34 
35 namespace cutlass {
36 
37 /******************************************************************************
38  * Static math utilities
39  ******************************************************************************/
40 
44 template <int N>
45 struct is_pow2 : platform::integral_constant<bool, (N & (N - 1)) == 0> {};
46 
50 template <int N, int CurrentVal = N, int Count = 0>
51 struct log2_down {
53  enum { value = log2_down<N, (CurrentVal >> 1), Count + 1>::value };
54 };
55 
56 // Base case
57 template <int N, int Count>
58 struct log2_down<N, 1, Count> {
59  enum { value = Count };
60 };
61 
65 template <int N, int CurrentVal = N, int Count = 0>
66 struct log2_up {
68  enum { value = log2_up<N, (CurrentVal >> 1), Count + 1>::value };
69 };
70 
71 // Base case
72 template <int N, int Count>
73 struct log2_up<N, 1, Count> {
74  enum { value = ((1 << Count) < N) ? Count + 1 : Count };
75 };
76 
80 template <int N>
81 struct sqrt_est {
82  enum { value = 1 << (log2_up<N>::value / 2) };
83 };
84 
89 template <int Dividend, int Divisor>
90 struct divide_assert {
91  enum { value = Dividend / Divisor };
92 
93  static_assert((Dividend % Divisor == 0), "Not an even multiple");
94 };
95 
96 /******************************************************************************
97  * Rounding
98  ******************************************************************************/
99 
103 template <typename dividend_t, typename divisor_t>
104 CUTLASS_HOST_DEVICE dividend_t round_nearest(dividend_t dividend, divisor_t divisor) {
105  return ((dividend + divisor - 1) / divisor) * divisor;
106 }
107 
111 template <typename value_t>
112 CUTLASS_HOST_DEVICE value_t gcd(value_t a, value_t b) {
113  for (;;) {
114  if (a == 0) return b;
115  b %= a;
116  if (b == 0) return a;
117  a %= b;
118  }
119 }
120 
124 template <typename value_t>
125 CUTLASS_HOST_DEVICE value_t lcm(value_t a, value_t b) {
126  value_t temp = gcd(a, b);
127 
128  return temp ? (a / temp * b) : 0;
129 }
130 
136 template <typename value_t>
137 CUTLASS_HOST_DEVICE value_t clz(value_t x) {
138  for (int i = 31; i >= 0; --i) {
139  if ((1 << i) & x) return 31 - i;
140  }
141  return 32;
142 }
143 
144 template <typename value_t>
145 CUTLASS_HOST_DEVICE value_t find_log2(value_t x) {
146  int a = 31 - clz(x);
147  a += (x & (x - 1)) != 0; // Round up, add 1 if not a power of 2.
148  return a;
149 }
150 
151 /******************************************************************************
152  * Min/Max
153  ******************************************************************************/
154 
155 template <int A, int B>
156 struct Min {
157  static int const kValue = (A < B) ? A : B;
158 };
159 
160 template <int A, int B>
161 struct Max {
162  static int const kValue = (A > B) ? A : B;
163 };
164 
165 } // namespace cutlass
Definition: cutlass_math.h:91
Definition: convert.h:33
+
static int const kValue
Definition: cutlass_math.h:157
+
CUTLASS_HOST_DEVICE value_t find_log2(value_t x)
Definition: cutlass_math.h:145
Definition: cutlass_math.h:51
C++ features that may be otherwise unimplemented for CUDA device functions.
+
Definition: cutlass_math.h:156
Definition: cutlass_math.h:53
CUTLASS_HOST_DEVICE value_t lcm(value_t a, value_t b)
Definition: cutlass_math.h:125
CUTLASS_HOST_DEVICE dividend_t round_nearest(dividend_t dividend, divisor_t divisor)
Definition: cutlass_math.h:104
Definition: cutlass_math.h:68
-
std::integral_constant
Definition: platform.h:274
+
std::integral_constant
Definition: platform.h:282
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
-
#define static_assert(__e, __m)
Definition: platform.h:145
+
#define static_assert(__e, __m)
Definition: platform.h:153
+
Definition: cutlass_math.h:161
Definition: cutlass_math.h:82
CUTLASS_HOST_DEVICE value_t gcd(value_t a, value_t b)
Definition: cutlass_math.h:112
Definition: cutlass_math.h:90
Definition: cutlass_math.h:66
+
CUTLASS_HOST_DEVICE value_t clz(value_t x)
Definition: cutlass_math.h:137
Definition: cutlass_math.h:45
+
static int const kValue
Definition: cutlass_math.h:162
Definition: cutlass_math.h:81
diff --git a/docs/debug_8h.html b/docs/debug_8h.html index 1f88396ab..81ed9f3ca 100644 --- a/docs/debug_8h.html +++ b/docs/debug_8h.html @@ -231,7 +231,7 @@ Functions
- + - +

Classes

struct  cutlass::gemm::DgemmConfig< OutputTile_, AccumulatorsPerThread_, kScalarsPerLdgA_, kScalarsPerLdgB_ >
struct  cutlass::gemm::DgemmConfig< OutputTile_, ThreadGemmShape_, kScalarsPerLdgA_, kScalarsPerLdgB_ >
 
struct  cutlass::gemm::DgemmTraits< kLayoutA_, kLayoutB_, OutputTile_, EpilogueFunctor_, AccumulatorsPerThread_, kScalarsPerLdgA_, kScalarsPerLdgB_, Index_, GemmConfig_, GemmEpilogueTraits_ >
struct  cutlass::gemm::DgemmTraits< kLayoutA_, kLayoutB_, OutputTile_, EpilogueFunctor_, ThreadGemmShape_, kScalarsPerLdgA_, kScalarsPerLdgB_, Index_, GemmConfig_, GemmEpilogueTraits_ >
 
diff --git a/docs/dgemm__traits_8h_source.html b/docs/dgemm__traits_8h_source.html index 9cf2c8738..d7cdbe529 100644 --- a/docs/dgemm__traits_8h_source.html +++ b/docs/dgemm__traits_8h_source.html @@ -76,26 +76,26 @@ $(function() {
dgemm_traits.h
-Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include <cutlass/gemm/gemm.h>
37 
38 namespace cutlass {
39 namespace gemm {
40 
42 
43 template <
45  typename OutputTile_,
47  typename AccumulatorsPerThread_,
49  int kScalarsPerLdgA_ = 1,
51  int kScalarsPerLdgB_ = 1>
53  : public GemmConfig<
55  double,
57  double,
59  double,
61  double,
63  OutputTile_,
65  ThreadMultiplyAdd<AccumulatorsPerThread_, Shape<1, 4, 8>, double, double, double>,
67  kScalarsPerLdgA_,
69  kScalarsPerLdgA_,
71  2,
73  kScalarsPerLdgB_,
75  kScalarsPerLdgB_,
77  2,
79  1,
81  2,
83  1,
85  2> {};
86 
88 
89 template <
91  MatrixLayout::Kind kLayoutA_,
93  MatrixLayout::Kind kLayoutB_,
95  typename OutputTile_ = Shape<8, 64, 128>,
97  typename EpilogueFunctor_ = LinearScaling<double>,
99  typename AccumulatorsPerThread_ = Shape<8, 8, 8>,
101  int kScalarsPerLdgA_ = 1,
103  int kScalarsPerLdgB_ = 1,
105  typename Index_ = int,
107  typename GemmConfig_ =
110  typename GemmEpilogueTraits_ =
113  // The layout for A.
114  kLayoutA_,
115  // The layout for B.
116  kLayoutB_,
117  // The config.
118  GemmConfig_,
119  // The epilogue.
120  GemmEpilogue<GemmEpilogueTraits_>,
121  // The index.
122  Index_> {};
123 
125 
126 } // namespace gemm
127 } // namespace cutlass
Definition: convert.h:33
+Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include "cutlass/gemm/gemm.h"
37 
38 namespace cutlass {
39 namespace gemm {
40 
42 
43 template <
45  typename OutputTile_,
47  typename ThreadGemmShape_,
49  int kScalarsPerLdgA_ = 1,
51  int kScalarsPerLdgB_ = 1>
53  : public GemmConfig<
55  double,
57  double,
59  double,
61  double,
63  OutputTile_,
65  ThreadMultiplyAdd<ThreadGemmShape_, Shape<1, 4, 8>, double, double, double>,
67  kScalarsPerLdgA_,
69  kScalarsPerLdgA_,
71  2,
73  kScalarsPerLdgB_,
75  kScalarsPerLdgB_,
77  2,
79  1,
81  2,
83  1,
85  2,
87  false,
89  false,
91  false
92  >{};
93 
95 
96 template <
98  MatrixLayout::Kind kLayoutA_,
100  MatrixLayout::Kind kLayoutB_,
102  typename OutputTile_ = Shape<8, 64, 128>,
104  typename EpilogueFunctor_ = LinearScaling<double>,
106  typename ThreadGemmShape_ = Shape<8, 8, 8>,
108  int kScalarsPerLdgA_ = 1,
110  int kScalarsPerLdgB_ = 1,
112  typename Index_ = int,
114  typename GemmConfig_ =
117  typename GemmEpilogueTraits_ =
120  // The layout for A.
121  kLayoutA_,
122  // The layout for B.
123  kLayoutB_,
124  // The config.
125  GemmConfig_,
126  // The epilogue.
127  GemmEpilogue<GemmEpilogueTraits_>,
128  // The index.
129  Index_> {};
130 
132 
133 } // namespace gemm
134 } // namespace cutlass
Definition: convert.h:33
Defines iterators for efficiently loading and storing to global memory.
Defines structural properties of complete GEMM computation.
Template implementing matrix multiply-add operations on fragments.
Implements the epilogue phase of the GEMM kernel that efficiently updates global memory with the comp...
Defines iterators for efficiently loading and storing tiles to and from shared memory.
-
Definition: gemm_traits.h:79
-
Definition: dgemm_traits.h:112
+
Definition: gemm_config.h:76
+
Definition: dgemm_traits.h:119
Definition: dgemm_traits.h:52
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
-
Definition: gemm_epilogue_traits.h:300
-
Kind
Definition: matrix_traits.h:36
-
Functor to compute linear combination of fragments.
Definition: linear_scaling.h:40
+
Definition: gemm_epilogue_traits.h:323
+
Kind
Enumeration defining fundamental contiguous layouts.
Definition: matrix_traits.h:159
+
Functor to compute linear combination of fragments.
Definition: linear_scaling.h:51
Implements a software-pipelined efficient GEMM.
Defines structural properties of the GEMM epilogue.
-
Definition: gemm_traits.h:723
+
Definition: gemm_traits.h:650
diff --git a/docs/dir_1417ee5ebebc309c36b7962f26a92c39.html b/docs/dir_1417ee5ebebc309c36b7962f26a92c39.html index d7393ef13..6555e36cf 100644 --- a/docs/dir_1417ee5ebebc309c36b7962f26a92c39.html +++ b/docs/dir_1417ee5ebebc309c36b7962f26a92c39.html @@ -101,15 +101,15 @@ Files
- - - + + + @@ -128,12 +128,24 @@ Files + + + + + + + + + + + + @@ -143,11 +155,20 @@ Files + + + + + + + + +

@@ -109,7 +109,7 @@ Namespaces

file  fragment.h [code]
 Defines Fragment, a statically-sized array for storing parts of matrices within a thread's registers.
 
file  fragment_load_store.h [code]
 Defines accessors for loading and storing fragments to memory efficiently.
 
file  fragment_multiply_add.h [code]
 Defines multiply-add operations on fragments within a thread.
 
file  iterator_access.h [code]
 Free functions for loading and storing to implementations of tile iteartor concepts.
 
file  kernel_launch.h [code]
 Defines structures and helpers to launch CUDA kernels within CUTLASS.
 
file  load_store.h [code]
 Defines abstractions for efficiently loading and storing vectors to memory.
 
file  tensor_ref.h [code]
 Defines a structure containing strides, bounds, and a pointer to tensor data.
 
file  tensor_ref_collection.h [code]
 Introduces TensorRefCollection concept and defines TensorRefBatch and TensorRefArray.
 
file  tensor_view.h [code]
 Defines a structure containing strides and a pointer to tensor data.
 
file  tile_allocation.h [code]
 Defines a fragment based on a Shape<> template.
 
file  tile_coord.h [code]
 Defines a coordinate used for the CUTLASS 4-D tile structure.
 
file  tile_iterator.h [code]
 Defines the Tile Traits concept and iterators for loading and storing to tiles efficiently.
 
file  tile_stream.h [code]
 Implements the tile stream concept, composing an iterator with a transformation. Offers split-phase semantics, separating the initiation of an asynchronous memory operation with a fence forcing it to complete.
 
file  tile_traits_standard.h [code]
 Defines tile traits for several tile partitioning arrangements of threads expected to achieve efficient streaming performance.
 
file  wmma_matrix.h [code]
 Abstractions for loading and storing matrices using the CUDA WMMA API.
 
file  zip_fragment.h [code]
 Models a pair of fragments.
 
file  zip_tensor_ref.h [code]
 Defines a structure containing a pair of TensorRef-like objects.
 
file  zip_tile_iterator.h [code]
 Constructs an iterator that owns two tile iterator instances.
 
diff --git a/docs/dir_18d6a367a3982a494d65599933fc67a3.html b/docs/dir_18d6a367a3982a494d65599933fc67a3.html index 161267475..b606ad3e0 100644 --- a/docs/dir_18d6a367a3982a494d65599933fc67a3.html +++ b/docs/dir_18d6a367a3982a494d65599933fc67a3.html @@ -85,9 +85,24 @@ Files
file  dgemm_traits.h [code]
 Defines structural traits of double-precision GEMM.
 
file  fp16_sgemm_multiply_add.h [code]
 Template implementing matrix multiply-add operations on fragments.
 
file  fp16_sgemm_traits.h [code]
 Defies structural properties of single-precision GEMM where any number of the input/output could be fp16 or fp32. The accumulator type stays in fp32.
 
file  gemm.h [code]
 Implements a software-pipelined efficient GEMM.
 
file  gemm_config.h [code]
 Defines properties of GEMM computation that impose some constraints on caller.
 
file  gemm_coord.h [code]
 GemmCoord is a structure derived from Coord<4> that specifies a location within the coordinate system of a GEMM problem.
 
file  gemm_desc.h [code]
 Implements a software-pipelined efficient GEMM.
 
file  gemm_epilogue.h [code]
 Implements the epilogue phase of the GEMM kernel that efficiently updates global memory with the computed matrix product.
 
file  gemm_shared_tile.h [code]
 Defines iterators for efficiently loading and storing tiles to and from shared memory.
 
file  gemm_stream_pair.h [code]
 Defines a pair of GEMM tile streams.
 
file  gemm_traits.h [code]
 Defines structural properties of complete GEMM computation.
 
file  hgemm_traits.h [code]
 Defies structural properties of half-precision GEMM computation.
 
file  identity_block_swizzle.h [code]
 Defies functors for mapping blockIdx to partitions of the GEMM computation.
 
file  igemm_epilogue.h [code]
 Defines the epilogue phase of the GEMM computation for IGEMM, supporting integer and floating-point output matrix formats.
 
file  linear_scaling.h [code]
 Implements the BLAS linear scaling function alpha*AB + beta*C.
 
file  linear_scaling_device_ptr.h [code]
 Implements the BLAS linear scaling function alpha*AB + beta*C.
 
file  scalar_or_pointer.h [code]
 Implements the BLAS linear scaling function alpha*AB + beta*C.
 
file  sgemm_traits.h [code]
 Defies structural properties of single-precision GEMM.
 
file  thread_multiply_add.h [code]
 Template implementing matrix multiply-add operations on fragments.
 
file  threadblock_swizzle.h [code]
 Defies functors for mapping blockIdx to partitions of the GEMM computation.
 
file  wmma_gemm_epilogue_traits.h [code]
 Defines structural properties of WMMA GEMM's epilogue phase.
 
+ + + + @@ -92,7 +96,7 @@ Files diff --git a/docs/files.html b/docs/files.html index 2c06de5a8..101952090 100644 --- a/docs/files.html +++ b/docs/files.html @@ -75,62 +75,79 @@ $(function() {
Here is a list of all files with brief descriptions:

Files

file  complex.h [code]
 
file  cutlass_math.h [code]
 Math utilities.
 
file  debug.h [code]
 Debugging and logging functionality.
 
file  numeric_types.h [code]
 
file  platform.h [code]
 C++ features that may be otherwise unimplemented for CUDA device functions.
 
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
 clear_accumulators.hDefines abstractions for efficiently clearing accumulator tiles
 convert.hDefines conversion operations among Fragments of different base type
 coord.hA Coord is a coordinate of arbitrary rank into a tensor or matrix
 core_io.hHelpers for printing cutlass/core objects
 cutlass.hBasic include for CUTLASS macros
 cutlass_math.hMath utilities
 debug.hDebugging and logging functionality
 dgemm_traits.hDefines structural traits of double-precision GEMM
 fragment.hDefines Fragment, a statically-sized array for storing parts of matrices within a thread's registers
 fragment_load_store.hDefines accessors for loading and storing fragments to memory efficiently
 fragment_multiply_add.hDefines multiply-add operations on fragments within a thread
 gemm.hImplements a software-pipelined efficient GEMM
 gemm_epilogue.hImplements the epilogue phase of the GEMM kernel that efficiently updates global memory with the computed matrix product
 gemm_epilogue_traits.hDefines structural properties of the GEMM epilogue
 gemm_global_stream.hImplements efficient loading of the thread block-level tile from global memory and storing to shared memory
 gemm_global_tile.hDefines iterators for efficiently loading and storing to global memory
 gemm_operand.hDefines constant expressions for mapping GEMM problem size and strides onto pitch-linear memory
 gemm_shared_stream.hDefines abstractions for managing loading and storing fragments to shared memory in the efficient GEMM pipeline
 gemm_shared_tile.hDefines iterators for efficiently loading and storing tiles to and from shared memory
 gemm_traits.hDefines structural properties of complete GEMM computation
 hgemm_global_tile.hTile traits used to construct global tile iterator for HGEMM. This is intended to partition the thread block-level tile into 2D subtiles loaded by the threads and facilitate memory accesses larger than 16 bits
 hgemm_multiply_add.hSpecialization implementing multiply-add operation on half-precision floating point fragments
 hgemm_swizzle.hTransposes a tile of 16b elements. Used by HGEMM to construct a K-strided layout in shared memory for multiplicands
 hgemm_traits.hDefies structural properties of half-precision GEMM computation
 identity_block_swizzle.hDefies functors for mapping blockIdx to partitions of the GEMM computation
 igemm_epilogue.hDefines the epilogue phase of the GEMM computation for IGEMM, supporting integer and floating-point output matrix formats
 igemm_global_tile.hImplements tile iterators to partition the thread block tile into 2D subtiles and efficiently load each. Applies permute transformation to construct 'interleaved K-strided' data layout in which 4-element dot products from the same K index are arranged in consecutive locations within shared memory
 igemm_multiply_add.hImplements matrix multiply accumulate operation of 8-bit integer data using DP4A instruction
 igemm_swizzle.hTransposes a fragment of data containing packed 8-bit integer elements
 igemm_traits.hDefies structural properties of mixed-precision integer GEMM. Multiplicands are assumed to be packed 8bit integers, accumulators are assumed to be 32b signed integers, and output formats vary
 iterator_access.hFree functions for loading and storing to implementations of tile iteartor concepts
 linear_scaling.hImplements the BLAS linear scaling function alpha*AB + beta*C
 load_store.hDefines abstractions for efficiently loading and storing vectors to memory
 matrix_traits.hDefines properties of matrices used to denote layout and operands to GEMM kernels
 platform.hC++ features that may be otherwise unimplemented for CUDA device functions
 predicate_vector.hDefines container classes and iterators for managing a statically sized vector of boolean predicates
 reshape_tile.hDefines a type for restructuring a tile
 sgemm_traits.hDefies structural properties of single-precision GEMM
 shape.hDefines Shape implementing the Layout concept for representing a 4D hypercube of objects
 tensor_ref.hDefines a structure containing strides, bounds, and a pointer to tensor data
 tensor_view.hDefines a structure containing strides and a pointer to tensor data
 thread_multiply_add.hTemplate implementing matrix multiply-add operations on fragments
 tile_iterator.hDefines the Tile Traits concept and iterators for loading and storing to tiles efficiently
 tile_traits_standard.hDefines tile traits for several tile partitioning arrangements of threads expected to achieve efficient streaming performance
 vector.hDefines a 1D vector of elements held in the registers of each thread
 wmma_gemm_epilogue_traits.hDefines structural properties of WMMA GEMM's epilogue phase
 wmma_gemm_global_tile.hDefines tile iterator traits for loading thread block-level tile from global memory
 wmma_gemm_multiply_add.hImplements warp-level matrix multiply-accumulate operation using CUDA WMMA API
 wmma_gemm_shared_tile.hDefines iterator traits for efficiently loading and storing fragment to and from shared memory, specialized for WMMA GEMM
 wmma_gemm_traits.hDefies structural properties of GEMM targeting WMMA API in CUDA
 wmma_matrix.hAbstractions for loading and storing matrices using the CUDA WMMA API
 complex.h
 convert.hDefines conversion operations among Fragments of different base type
 coord.hA Coord is a coordinate of arbitrary rank into a tensor or matrix
 core_io.hHelpers for printing cutlass/core objects
 cutlass.hBasic include for CUTLASS macros
 cutlass_math.hMath utilities
 debug.hDebugging and logging functionality
 dgemm_traits.hDefines structural traits of double-precision GEMM
 fp16_sgemm_multiply_add.hTemplate implementing matrix multiply-add operations on fragments
 fp16_sgemm_traits.hDefies structural properties of single-precision GEMM where any number of the input/output could be fp16 or fp32. The accumulator type stays in fp32
 fragment.hDefines Fragment, a statically-sized array for storing parts of matrices within a thread's registers
 fragment_multiply_add.hDefines multiply-add operations on fragments within a thread
 gemm.hImplements a software-pipelined efficient GEMM
 gemm_config.hDefines properties of GEMM computation that impose some constraints on caller
 gemm_coord.hGemmCoord is a structure derived from Coord<4> that specifies a location within the coordinate system of a GEMM problem
 gemm_desc.hImplements a software-pipelined efficient GEMM
 gemm_epilogue.hImplements the epilogue phase of the GEMM kernel that efficiently updates global memory with the computed matrix product
 gemm_epilogue_traits.hDefines structural properties of the GEMM epilogue
 gemm_global_stream.hImplements efficient loading of the thread block-level tile from global memory and storing to shared memory
 gemm_global_tile.hDefines iterators for efficiently loading and storing to global memory
 gemm_operand.hDefines constant expressions for mapping GEMM problem size and strides onto pitch-linear memory
 gemm_shared_stream.hDefines abstractions for managing loading and storing fragments to shared memory in the efficient GEMM pipeline
 gemm_shared_tile.hDefines iterators for efficiently loading and storing tiles to and from shared memory
 gemm_stream_pair.hDefines a pair of GEMM tile streams
 gemm_traits.hDefines structural properties of complete GEMM computation
 hgemm_global_tile.hTile traits used to construct global tile iterator for HGEMM. This is intended to partition the thread block-level tile into 2D subtiles loaded by the threads and facilitate memory accesses larger than 16 bits
 hgemm_multiply_add.hSpecialization implementing multiply-add operation on half-precision floating point fragments
 hgemm_swizzle.hTransposes a tile of 16b elements. Used by HGEMM to construct a K-strided layout in shared memory for multiplicands
 hgemm_traits.hDefies structural properties of half-precision GEMM computation
 igemm_epilogue.hDefines the epilogue phase of the GEMM computation for IGEMM, supporting integer and floating-point output matrix formats
 igemm_global_tile.hImplements tile iterators to partition the thread block tile into 2D subtiles and efficiently load each. Applies permute transformation to construct 'interleaved K-strided' data layout in which 4-element dot products from the same K index are arranged in consecutive locations within shared memory
 igemm_multiply_add.hImplements matrix multiply accumulate operation of 8-bit integer data using DP4A instruction
 igemm_swizzle.hTransposes a fragment of data containing packed 8-bit integer elements
 igemm_traits.hDefies structural properties of mixed-precision integer GEMM. Multiplicands are assumed to be packed 8bit integers, accumulators are assumed to be 32b signed integers, and output formats vary
 iterator_access.hFree functions for loading and storing to implementations of tile iteartor concepts
 kernel_launch.hDefines structures and helpers to launch CUDA kernels within CUTLASS
 linear_scaling.hImplements the BLAS linear scaling function alpha*AB + beta*C
 linear_scaling_device_ptr.hImplements the BLAS linear scaling function alpha*AB + beta*C
 load_store.hDefines abstractions for efficiently loading and storing vectors to memory
 matrix_traits.hDefines properties of matrices used to denote layout and operands to GEMM kernels
 numeric_types.h
 platform.hC++ features that may be otherwise unimplemented for CUDA device functions
 predicate_vector.hDefines container classes and iterators for managing a statically sized vector of boolean predicates
 reshape_tile.hDefines a type for restructuring a tile
 scalar_or_pointer.hImplements the BLAS linear scaling function alpha*AB + beta*C
 sgemm_traits.hDefies structural properties of single-precision GEMM
 shape.hDefines Shape implementing the Layout concept for representing a 4D hypercube of objects
 tensor_ref.hDefines a structure containing strides, bounds, and a pointer to tensor data
 tensor_ref_collection.hIntroduces TensorRefCollection concept and defines TensorRefBatch and TensorRefArray
 tensor_view.hDefines a structure containing strides and a pointer to tensor data
 thread_multiply_add.hTemplate implementing matrix multiply-add operations on fragments
 threadblock_swizzle.hDefies functors for mapping blockIdx to partitions of the GEMM computation
 tile_allocation.hDefines a fragment based on a Shape<> template
 tile_coord.hDefines a coordinate used for the CUTLASS 4-D tile structure
 tile_iterator.hDefines the Tile Traits concept and iterators for loading and storing to tiles efficiently
 tile_stream.hImplements the tile stream concept, composing an iterator with a transformation. Offers split-phase semantics, separating the initiation of an asynchronous memory operation with a fence forcing it to complete
 tile_traits_standard.hDefines tile traits for several tile partitioning arrangements of threads expected to achieve efficient streaming performance
 vector.hDefines a 1D vector of elements held in the registers of each thread
 wmma_gemm_epilogue_traits.hDefines structural properties of WMMA GEMM's epilogue phase
 wmma_gemm_global_tile.hDefines tile iterator traits for loading thread block-level tile from global memory
 wmma_gemm_multiply_add.hImplements warp-level matrix multiply-accumulate operation using CUDA WMMA API
 wmma_gemm_shared_tile.hDefines iterator traits for efficiently loading and storing fragment to and from shared memory, specialized for WMMA GEMM
 wmma_gemm_traits.hDefies structural properties of GEMM targeting WMMA API in CUDA
 wmma_matrix.hAbstractions for loading and storing matrices using the CUDA WMMA API
 zip_fragment.hModels a pair of fragments
 zip_tensor_ref.hDefines a structure containing a pair of TensorRef-like objects
 zip_tile_iterator.hConstructs an iterator that owns two tile iterator instances
diff --git a/docs/fp16__sgemm__multiply__add_8h.html b/docs/fp16__sgemm__multiply__add_8h.html new file mode 100644 index 000000000..deff050ef --- /dev/null +++ b/docs/fp16__sgemm__multiply__add_8h.html @@ -0,0 +1,111 @@ + + + + + + + +Cutlass: fp16_sgemm_multiply_add.h File Reference + + + + + + + + + + +
+
+ + + + + + +
+
Cutlass +
+
CUDA Templates for Linear Algebra Subroutines and Solvers
+
+
+ + + + + + + + +
+
+ + +
+ +
+ + +
+
+ +
+
fp16_sgemm_multiply_add.h File Reference
+
+
+ +

Template implementing matrix multiply-add operations on fragments. +More...

+ +

Go to the source code of this file.

+ + + + + +

+Classes

struct  cutlass::gemm::ThreadMultiplyAdd< ThreadGemmShape_, ThreadsPerWarp_, half, half, float >
 Template performing matrix multiply-add operation within a thread. More...
 
+ + + + + +

+Namespaces

 cutlass
 
 cutlass::gemm
 
+
+ + + + diff --git a/docs/fp16__sgemm__multiply__add_8h_source.html b/docs/fp16__sgemm__multiply__add_8h_source.html new file mode 100644 index 000000000..efac04637 --- /dev/null +++ b/docs/fp16__sgemm__multiply__add_8h_source.html @@ -0,0 +1,107 @@ + + + + + + + +Cutlass: fp16_sgemm_multiply_add.h Source File + + + + + + + + + + +
+
+ + + + + + +
+
Cutlass +
+
CUDA Templates for Linear Algebra Subroutines and Solvers
+
+
+ + + + + + + + +
+
+ + +
+ +
+ + +
+
+
+
fp16_sgemm_multiply_add.h
+
+
+Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include "cutlass/fragment.h"
32 namespace cutlass {
33 namespace gemm {
34 
36 
38 template <typename ThreadGemmShape_,
39  typename ThreadsPerWarp_>
40 struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, half, half, float> {
44  typedef ThreadGemmShape_ ThreadGemmShape;
48  typedef ThreadsPerWarp_ ThreadsPerWarp;
52  typedef half ScalarA;
56  typedef half ScalarB;
60  typedef float ScalarC;
63 
65  CUTLASS_DEVICE ThreadMultiplyAdd() {}
66 
68  CUTLASS_DEVICE void multiply_add(FragmentA const& a,
69  FragmentB const& b,
70  Accumulators const& c,
71  Accumulators& d) {
72  for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
73  for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {
74  d[j * AccumulatorsPerThread::kW + i] = static_cast<ScalarC>(a[i]) * static_cast<ScalarC>(b[j]) + c[j * AccumulatorsPerThread::kW + i];
75  }
76  }
77  }
78 };
79 
81 
82 } // namespace gemm
83 } // namespace cutlass
Definition: convert.h:33
+
CUTLASS_DEVICE ThreadMultiplyAdd()
Ctor.
Definition: fp16_sgemm_multiply_add.h:65
+
Fragment< ScalarB, AccumulatorsPerThread::kH > FragmentB
The fragment for B.
Definition: fp16_sgemm_multiply_add.h:58
+
Shape< A_::kD *B_::kD, A_::kH *B_::kH, A_::kW *B_::kW, A_::kC *B_::kC > Shape
Definition: shape.h:119
+
A template defining Fragment Concept.
Definition: fragment.h:99
+
ShapeMul< ThreadGemmShape, ThreadsPerWarp >::Shape AccumulatorsPerWarp
The number of accumulators per warp.
Definition: fp16_sgemm_multiply_add.h:50
+
Template implementing matrix multiply-add operations on fragments.
+
ThreadGemmShape_ ThreadGemmShape
The shape of a thread-leveel matrix multiply accumulate.
Definition: fp16_sgemm_multiply_add.h:44
+
CUTLASS_DEVICE void multiply_add(FragmentA const &a, FragmentB const &b, Accumulators const &c, Accumulators &d)
Multiply : d = a*b + c.
Definition: fp16_sgemm_multiply_add.h:68
+
half ScalarA
The type for A. specialized to half.
Definition: fp16_sgemm_multiply_add.h:52
+
half ScalarB
The type for B. specialized to half.
Definition: fp16_sgemm_multiply_add.h:56
+
ThreadsPerWarp_ ThreadsPerWarp
The number of threads per warp.
Definition: fp16_sgemm_multiply_add.h:48
+
Fragment< ScalarA, AccumulatorsPerThread::kW > FragmentA
The fragment for A.
Definition: fp16_sgemm_multiply_add.h:54
+
float ScalarC
The type for C and D. specialized to float.
Definition: fp16_sgemm_multiply_add.h:60
+
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
+
Fragment< ScalarC, AccumulatorsPerThread::kH *AccumulatorsPerThread::kW, 16 > Accumulators
The accumulators.
Definition: fp16_sgemm_multiply_add.h:62
+
ThreadGemmShape AccumulatorsPerThread
Aliased to "AccumulatorsPerThread" for compatibility. Expect to be renamed in CUTLASS v2...
Definition: fp16_sgemm_multiply_add.h:46
+
Template performing matrix multiply-add operation within a thread.
Definition: thread_multiply_add.h:44
+
Defines Fragment, a statically-sized array for storing parts of matrices within a thread&#39;s registers...
+
Shape< 1, 1, 1, 1 > InstructionShape
The shape of the instruction.
Definition: fp16_sgemm_multiply_add.h:42
+
+ + + + diff --git a/docs/fp16__sgemm__traits_8h.html b/docs/fp16__sgemm__traits_8h.html new file mode 100644 index 000000000..0691fbbfc --- /dev/null +++ b/docs/fp16__sgemm__traits_8h.html @@ -0,0 +1,117 @@ + + + + + + + +Cutlass: fp16_sgemm_traits.h File Reference + + + + + + + + + + +
+
+ + + + + + +
+
Cutlass +
+
CUDA Templates for Linear Algebra Subroutines and Solvers
+
+
+ + + + + + + + +
+
+ + +
+ +
+ + +
+
+ +
+
fp16_sgemm_traits.h File Reference
+
+ + + + + diff --git a/docs/fp16__sgemm__traits_8h_source.html b/docs/fp16__sgemm__traits_8h_source.html new file mode 100644 index 000000000..b5f94457f --- /dev/null +++ b/docs/fp16__sgemm__traits_8h_source.html @@ -0,0 +1,104 @@ + + + + + + + +Cutlass: fp16_sgemm_traits.h Source File + + + + + + + + + + +
+
+ + + + + + +
+
Cutlass +
+
CUDA Templates for Linear Algebra Subroutines and Solvers
+
+
+ + + + + + + + +
+
+ + +
+ +
+ + +
+
+
+
fp16_sgemm_traits.h
+
+
+Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include "cutlass/gemm/gemm.h"
38 
39 namespace cutlass {
40 namespace gemm {
41 
43 
44 template <
46  typename OutputTile_,
48  typename ThreadGemmShape_,
50  typename ScalarA_,
52  typename ScalarB_,
54  typename ScalarC_,
56  typename ScalarD_,
58  int kScalarsPerLdgA_ = 1,
60  int kScalarsPerLdgB_ = 1>
61 struct Fp16SgemmConfig : public GemmConfig<
63  ScalarA_,
65  ScalarB_,
67  ScalarC_,
69  ScalarD_,
71  OutputTile_,
73  ThreadMultiplyAdd<ThreadGemmShape_, Shape<1, 4, 8>, ScalarA_, ScalarB_, float /*for sgemm accum is float*/>,
75  kScalarsPerLdgA_,
77  kScalarsPerLdgA_,
79  4,
81  kScalarsPerLdgB_,
83  kScalarsPerLdgB_,
85  4,
87  1,
89  4,
91  1,
93  2> {};
94 
96 
97 template <
99  MatrixLayout::Kind kLayoutA_,
101  MatrixLayout::Kind kLayoutB_,
103  typename OutputTile_ = Shape<8, 128, 128>,
105  typename ScalarA_ = half,
107  typename ScalarB_ = half,
109  typename ScalarC_ = half,
111  typename ScalarD_ = half,
113  typename Scalar_ = half,
115  typename EpilogueFunctor_ = LinearScaling<Scalar_, FragmentMultiplyAdd<Scalar_, float/*accumulator type*/> >,
117  typename ThreadGemmShape_ = Shape<8, 8, 8>,
119  int kScalarsPerLdgA_ = 1,
121  int kScalarsPerLdgB_ = 1,
123  typename Index_ = int,
125  typename GemmConfig_ =
126  Fp16SgemmConfig<OutputTile_,
127  ThreadGemmShape_,
128  ScalarA_,
129  ScalarB_,
130  ScalarC_,
131  ScalarD_,
132  kScalarsPerLdgA_,
133  kScalarsPerLdgB_>,
135  typename GemmEpilogueTraits_ =
138  // The layout for A.
139  kLayoutA_,
140  // The layout for B.
141  kLayoutB_,
142  // The config.
143  GemmConfig_,
144  // The epilogue.
145  GemmEpilogue<GemmEpilogueTraits_>,
146  // The index.
147  Index_> {};
148 
150 
151 } // namespace gemm
152 } // namespace cutlass
Definition: convert.h:33
+
Defines iterators for efficiently loading and storing to global memory.
+
Defines structural properties of complete GEMM computation.
+
Implements the epilogue phase of the GEMM kernel that efficiently updates global memory with the comp...
+
Defines iterators for efficiently loading and storing tiles to and from shared memory.
+
Definition: gemm_config.h:76
+
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
+
Definition: gemm_epilogue_traits.h:323
+
Definition: fp16_sgemm_traits.h:61
+
Kind
Enumeration defining fundamental contiguous layouts.
Definition: matrix_traits.h:159
+
Template implementing matrix multiply-add operations on fragments.
+
Functor to compute linear combination of fragments.
Definition: linear_scaling.h:51
+
Implements a software-pipelined efficient GEMM.
+
Defines structural properties of the GEMM epilogue.
+
Definition: fp16_sgemm_traits.h:137
+
Definition: gemm_traits.h:650
+
Definition: fragment_multiply_add.h:41
+
+ + + + diff --git a/docs/fragment_8h.html b/docs/fragment_8h.html index d97ac7b5a..687dfdc86 100644 --- a/docs/fragment_8h.html +++ b/docs/fragment_8h.html @@ -83,15 +83,15 @@ $(function() {

Defines Fragment, a statically-sized array for storing parts of matrices within a thread's registers. More...

#include <assert.h>
-#include <cutlass/shape.h>
-#include <cutlass/util/cutlass_math.h>
-#include <cutlass/vector.h>
+#include "cutlass/shape.h"
+#include "cutlass/util/cutlass_math.h"
+#include "cutlass/vector.h"

Go to the source code of this file.

- + @@ -116,7 +116,7 @@ Namespaces diff --git a/docs/fragment_8h_source.html b/docs/fragment_8h_source.html index 8006bbbdf..f7d236565 100644 --- a/docs/fragment_8h_source.html +++ b/docs/fragment_8h_source.html @@ -76,64 +76,66 @@ $(function() {
fragment.h
-Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include <assert.h>
32 #include <cutlass/shape.h>
34 #include <cutlass/vector.h>
35 
36 namespace cutlass {
37 
39 
56 
73 
75 template <int kAlignment_>
76 struct StorageType {
77  typedef uint64_t Type;
78 };
79 template <>
80 struct StorageType<4> {
81  typedef uint32_t Type;
82 };
83 template <>
84 struct StorageType<2> {
85  typedef uint16_t Type;
86 };
87 template <>
88 struct StorageType<1> {
89  typedef uint8_t Type;
90 };
91 
93 
98 template <typename Element_, int kElements_, size_t kAlignment_ = 16>
99 struct Fragment : public AlignedStruct<kAlignment_> {
101  static_assert(kAlignment_ == 16 || kAlignment_ >= sizeof(Element_), "Alignment is too small");
103  static_assert(is_pow2<kAlignment_>::value, "Alignment must be a power of two");
104 
108  typedef Element_ Element;
110  static int const kElements = kElements_;
111 
113  CUTLASS_DEVICE void clear() {
114  // Avoid element-wise access for sub 32b element type
115  if (kAlignment_ >= 8 && (kElements * sizeof(Element)) % 8 == 0) {
116  uint64_t* ptr = reinterpret_cast<uint64_t*>(storage);
117  for (int i = 0; i < (kElements * sizeof(Element)) / 8; ++i) {
118  ptr[i] = uint64_t(0);
119  }
120  } else if (kAlignment_ >= 4 && (kElements * sizeof(Element)) % 4 == 0) {
121  uint32_t* ptr = reinterpret_cast<uint32_t*>(storage);
122  for (int i = 0; i < (kElements * sizeof(Element)) / 4; ++i) {
123  ptr[i] = uint32_t(0);
124  }
125  } else if (kAlignment_ >= 2 && (kElements * sizeof(Element)) % 2 == 0) {
126  uint16_t* ptr = reinterpret_cast<uint16_t*>(storage);
127  for (int i = 0; i < (kElements * sizeof(Element)) / 2; ++i) {
128  ptr[i] = uint16_t(0);
129  }
130  } else {
131  for (int i = 0; i < kElements; ++i) {
132  storage[i] = 0;
133  }
134  }
135  }
136 
138  CUTLASS_DEVICE Element& operator[](int i) {
139  assert(i < kElements_);
140  return reinterpret_cast<Element*>(storage)[i];
141  }
142 
144  CUTLASS_DEVICE Element const& operator[](int i) const {
145  assert(i < kElements_);
146  return reinterpret_cast<Element const*>(storage)[i];
147  }
148 
149  private:
152 
154  static int const kStorageCount =
155  (sizeof(Element_) * kElements_ + sizeof(StorageType) - 1) / sizeof(StorageType);
157  StorageType storage[kStorageCount];
158 
160  static_assert(sizeof(StorageType) <= kAlignment_, "StorageType is too big for given alignment");
161 };
162 
164 
169 template <typename Fragment_, typename Iterations_, typename AccessType_>
174  typedef Fragment_ Fragment;
176  typedef Iterations_ Iterations;
178  typedef AccessType_ AccessType;
179 
181  typedef typename Fragment::Element Element;
183  static int const kElementsPerAccess = (int)(sizeof(AccessType) / sizeof(Element));
188 
190  template <typename OtherFragment_>
191  CUTLASS_DEVICE FragmentIterator(OtherFragment_& fragment, int offset = 0)
192  : pointer(reinterpret_cast<Element*>(&fragment[offset])) {
193  static_assert(OtherFragment_::kElements >= Fragment::kElements, "");
194  }
195 
197  CUTLASS_DEVICE AccessType const& at(int d, int h, int w, int c = 0) const {
198  int const imm = ComputeOffsetFromStrides<Strides>::get(d, h, w, c);
199  return reinterpret_cast<AccessType const&>(pointer[imm]);
200  }
201 
203  CUTLASS_DEVICE AccessType& at(int d, int h, int w, int c = 0) {
204  int const imm = ComputeOffsetFromStrides<Strides>::get(d, h, w, c);
205  return reinterpret_cast<AccessType&>(pointer[imm]);
206  }
207 
209  CUTLASS_DEVICE AccessType const& operator[](int i) const {
210  return reinterpret_cast<AccessType const&>(pointer[i * kElementsPerAccess]);
211  }
212 
214  CUTLASS_DEVICE AccessType& operator[](int i) {
215  return reinterpret_cast<AccessType&>(pointer[i * kElementsPerAccess]);
216  }
217 
219  CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
220 
223 };
224 
226 
227 template <typename Fragment_, typename Iterations_, typename AccessType_>
232  typedef Fragment_ Fragment;
234  typedef Iterations_ Iterations;
236  typedef AccessType_ AccessType;
237 
239  typedef typename Fragment::Element Element;
241  static int const kElementsPerAccess = (int)(sizeof(AccessType) / sizeof(Element));
246 
248  template <typename OtherFragment_>
249  CUTLASS_DEVICE FragmentConstIterator(OtherFragment_& fragment, int offset = 0)
250  : pointer(reinterpret_cast<Element const*>(&fragment[offset])) {
251  static_assert(OtherFragment_::kElements >= Fragment::kElements, "");
252  }
254  CUTLASS_DEVICE FragmentConstIterator(
256  : pointer(reinterpret_cast<Element const*>(rhs_.offset)) {}
257 
259  CUTLASS_DEVICE AccessType const& at(int d, int h, int w, int c = 0) const {
260  int const imm = ComputeOffsetFromStrides<IterationsStrides>::get(d, h, w, c);
261  return reinterpret_cast<AccessType const&>(pointer[imm]);
262  }
263 
265  CUTLASS_DEVICE AccessType const& operator[](int i) const {
266  return reinterpret_cast<AccessType const&>(pointer[i * kElementsPerAccess]);
267  }
268 
270  CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
271 
273  Element const* pointer;
274 };
275 
277 
278 } // namespace cutlass
CUTLASS_DEVICE void clear()
Clear a fragment.
Definition: fragment.h:113
+Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include <assert.h>
32 #include "cutlass/shape.h"
34 #include "cutlass/vector.h"
35 
36 namespace cutlass {
37 
39 
56 
73 
75 template <int alignment>
76 struct StorageType {
77  typedef uint64_t Type;
78 };
79 template <>
80 struct StorageType<4> {
81  typedef uint32_t Type;
82 };
83 template <>
84 struct StorageType<2> {
85  typedef uint16_t Type;
86 };
87 template <>
88 struct StorageType<1> {
89  typedef uint8_t Type;
90 };
91 
93 
98 template <typename Element_, int kElements_, size_t kAlignment_ = 16>
99 struct Fragment : public AlignedStruct<kAlignment_> {
101  static_assert(kAlignment_ == 16 || kAlignment_ >= sizeof(Element_), "Alignment is too small");
103  static_assert(is_pow2<kAlignment_>::value, "Alignment must be a power of two");
104 
108  typedef Element_ Element;
110  static int const kElements = kElements_;
112  static int const kAlignment = kAlignment_;
113 
116  // Avoid element-wise access for sub 32b element type
117  if (kAlignment_ >= 8 && (kElements * sizeof(Element)) % 8 == 0) {
118  uint64_t* ptr = reinterpret_cast<uint64_t*>(storage);
119  for (int i = 0; i < (kElements * sizeof(Element)) / 8; ++i) {
120  ptr[i] = uint64_t(0);
121  }
122  } else if (kAlignment_ >= 4 && (kElements * sizeof(Element)) % 4 == 0) {
123  uint32_t* ptr = reinterpret_cast<uint32_t*>(storage);
124  for (int i = 0; i < (kElements * sizeof(Element)) / 4; ++i) {
125  ptr[i] = uint32_t(0);
126  }
127  } else if (kAlignment_ >= 2 && (kElements * sizeof(Element)) % 2 == 0) {
128  uint16_t* ptr = reinterpret_cast<uint16_t*>(storage);
129  for (int i = 0; i < (kElements * sizeof(Element)) / 2; ++i) {
130  ptr[i] = uint16_t(0);
131  }
132  } else {
133  for (int i = 0; i < kElements; ++i) {
134  storage[i] = 0;
135  }
136  }
137  }
138 
140  CUTLASS_HOST_DEVICE Element& operator[](int i) { return reinterpret_cast<Element*>(storage)[i]; }
141 
143  CUTLASS_HOST_DEVICE Element const& operator[](int i) const {
144  return reinterpret_cast<Element const*>(storage)[i];
145  }
146 
147  private:
150 
152  static int const kStorageCount =
153  (sizeof(Element_) * kElements_ + sizeof(StorageType) - 1) / sizeof(StorageType);
155  StorageType storage[kStorageCount];
156 
158  static_assert(sizeof(StorageType) <= kAlignment_, "StorageType is too big for given alignment");
159 };
160 
162 
167 template <typename Fragment_, typename Iterations_, typename AccessType_>
172  typedef Fragment_ Fragment;
174  typedef Iterations_ Iterations;
176  typedef AccessType_ AccessType;
177 
179  typedef typename Fragment::Element Element;
181  static int const kElementsPerAccess = (int)(sizeof(AccessType) / sizeof(Element));
186 
188  template <typename OtherFragment_>
189  CUTLASS_HOST_DEVICE FragmentIterator(OtherFragment_& fragment, int offset = 0)
190  : pointer(reinterpret_cast<Element*>(&fragment[offset])) {
191  static_assert(OtherFragment_::kElements >= Fragment::kElements, "");
192  }
193 
195  CUTLASS_HOST_DEVICE AccessType const& at(int d, int h, int w, int c = 0) const {
196  int const imm = ComputeOffsetFromStrides<Strides>::get(d, h, w, c);
197  return reinterpret_cast<AccessType const&>(pointer[imm]);
198  }
199 
201  CUTLASS_HOST_DEVICE AccessType& at(int d, int h, int w, int c = 0) {
202  int const imm = ComputeOffsetFromStrides<Strides>::get(d, h, w, c);
203  return reinterpret_cast<AccessType&>(pointer[imm]);
204  }
205 
208  return reinterpret_cast<AccessType const&>(pointer[i * kElementsPerAccess]);
209  }
210 
213  return reinterpret_cast<AccessType&>(pointer[i * kElementsPerAccess]);
214  }
215 
217  CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
218 
221 };
222 
224 
225 template <typename Fragment_, typename Iterations_, typename AccessType_>
230  typedef Fragment_ Fragment;
232  typedef Iterations_ Iterations;
234  typedef AccessType_ AccessType;
235 
237  typedef typename Fragment::Element Element;
239  static int const kElementsPerAccess = (int)(sizeof(AccessType) / sizeof(Element));
244 
246  template <typename OtherFragment_>
247  CUTLASS_HOST_DEVICE FragmentConstIterator(OtherFragment_& fragment, int offset = 0)
248  : pointer(reinterpret_cast<Element const*>(&fragment[offset])) {
249  static_assert(OtherFragment_::kElements >= Fragment::kElements, "");
250  }
254  : pointer(reinterpret_cast<Element const*>(rhs_.offset)) {}
255 
257  CUTLASS_HOST_DEVICE AccessType const& at(int d, int h, int w, int c = 0) const {
258  int const imm = ComputeOffsetFromStrides<IterationsStrides>::get(d, h, w, c);
259  return reinterpret_cast<AccessType const&>(pointer[imm]);
260  }
261 
264  return reinterpret_cast<AccessType const&>(pointer[i * kElementsPerAccess]);
265  }
266 
268  CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const { return true; }
269 
271  Element const* pointer;
272 };
273 
275 
276 } // namespace cutlass
CUTLASS_HOST_DEVICE void clear()
Clear a fragment.
Definition: fragment.h:115
+
CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const
Is the iterator valid?
Definition: fragment.h:217
Definition: convert.h:33
-
CUTLASS_DEVICE Element & operator[](int i)
The accessor.
Definition: fragment.h:138
-
CUTLASS_DEVICE AccessType & at(int d, int h, int w, int c=0)
The accessor.
Definition: fragment.h:203
-
Definition: vector.h:41
-
Definition: fragment.h:228
-
CUTLASS_DEVICE AccessType const & operator[](int i) const
The accessor.
Definition: fragment.h:265
-
Shape< Shape_::kH *Shape_::kW *Shape_::kC, Shape_::kW *Shape_::kC, Shape_::kC, 1 > Shape
Definition: shape.h:155
+
Shape< Shape_::kH *Shape_::kW *Shape_::kC, Shape_::kW *Shape_::kC, Shape_::kC, elementsPerAccess > Shape
Definition: shape.h:170
+
Definition: vector.h:42
+
Definition: fragment.h:226
+
CUTLASS_HOST_DEVICE FragmentIterator(OtherFragment_ &fragment, int offset=0)
Ctor.
Definition: fragment.h:189
A template defining Fragment Concept.
Definition: fragment.h:99
-
Fragment::Element Element
The element.
Definition: fragment.h:181
-
static int const kElementsPerAccess
The number of elements per access.
Definition: fragment.h:241
-
Fragment_ Fragment
The fragment.
Definition: fragment.h:174
-
Fragment_ Fragment
The fragment.
Definition: fragment.h:232
-
CUTLASS_DEVICE AccessType & operator[](int i)
The accessor.
Definition: fragment.h:214
-
Fragment::Element Element
The element.
Definition: fragment.h:239
-
ShapeStrides< FragmentShape >::Shape IterationsStrides
The linear strides for iterations.
Definition: fragment.h:245
-
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const
Is the iterator valid?
Definition: fragment.h:270
-
CUTLASS_DEVICE FragmentIterator(OtherFragment_ &fragment, int offset=0)
Ctor.
Definition: fragment.h:191
+
Fragment::Element Element
The element.
Definition: fragment.h:179
+
static int const kElementsPerAccess
The number of elements per access.
Definition: fragment.h:239
+
Fragment_ Fragment
The fragment.
Definition: fragment.h:172
+
Fragment_ Fragment
The fragment.
Definition: fragment.h:230
+
Fragment::Element Element
The element.
Definition: fragment.h:237
Fragment< Element_, kElements_ > This_
Make sure the alignment makes sense wrt the size of elements.
Definition: fragment.h:101
-
FragmentIterator< Fragment_, Iterations_, AccessType_ > This_
This class.
Definition: fragment.h:172
-
ShapeMul< Iterations, Shape< 1, 1, 1, kElementsPerAccess > >::Shape FragmentShape
The shape of the the fragment.
Definition: fragment.h:243
+
FragmentIterator< Fragment_, Iterations_, AccessType_ > This_
This class.
Definition: fragment.h:170
+
ShapeMul< Iterations, Shape< 1, 1, 1, kElementsPerAccess > >::Shape FragmentShape
The shape of the the fragment.
Definition: fragment.h:241
Math utilities.
Definition: fragment.h:76
uint32_t Type
Definition: fragment.h:81
uint8_t Type
Definition: fragment.h:89
-
static CUTLASS_DEVICE int get(int d, int h, int w, int c)
Definition: shape.h:211
-
Element * pointer
The pointer.
Definition: fragment.h:222
-
AccessType_ AccessType
The access type.
Definition: fragment.h:236
+
Element * pointer
The pointer.
Definition: fragment.h:220
+
CUTLASS_HOST_DEVICE Element const & operator[](int i) const
The accessor.
Definition: fragment.h:143
+
AccessType_ AccessType
The access type.
Definition: fragment.h:234
+
ShapeStrides< FragmentShape, kElementsPerAccess >::Shape IterationsStrides
The linear strides for iterations.
Definition: fragment.h:243
Definition: shape.h:118
-
ShapeMul< Iterations, Shape< 1, 1, 1, kElementsPerAccess > >::Shape FragmentShape
The shape of the the fragment.
Definition: fragment.h:185
-
A template defining Fragment Iterator Concept.
Definition: fragment.h:170
+
ShapeMul< Iterations, Shape< 1, 1, 1, kElementsPerAccess > >::Shape FragmentShape
The shape of the the fragment.
Definition: fragment.h:183
+
CUTLASS_HOST_DEVICE FragmentConstIterator(OtherFragment_ &fragment, int offset=0)
Ctor.
Definition: fragment.h:247
+
A template defining Fragment Iterator Concept.
Definition: fragment.h:168
static int const kElements
The number of elements.
Definition: fragment.h:110
-
CUTLASS_DEVICE Element const & operator[](int i) const
The accessor.
Definition: fragment.h:144
-
Iterations_ Iterations
The number of iterations.
Definition: fragment.h:234
-
#define static_assert(__e, __m)
Definition: platform.h:145
-
Iterations_ Iterations
The number of iterations.
Definition: fragment.h:176
+
#define CUTLASS_HOST_DEVICE
Definition: cutlass.h:46
+
Iterations_ Iterations
The number of iterations.
Definition: fragment.h:232
+
CUTLASS_HOST_DEVICE AccessType const & at(int d, int h, int w, int c=0) const
The accessor.
Definition: fragment.h:195
+
#define static_assert(__e, __m)
Definition: platform.h:153
+
Iterations_ Iterations
The number of iterations.
Definition: fragment.h:174
+
CUTLASS_HOST_DEVICE FragmentConstIterator(FragmentIterator< Fragment_, Iterations_, AccessType_ > const &rhs_)
Create from non-constant FragmentIterator.
Definition: fragment.h:252
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
-
CUTLASS_DEVICE AccessType const & at(int d, int h, int w, int c=0) const
The accessor.
Definition: fragment.h:259
Element_ Element
The element.
Definition: fragment.h:108
-
FragmentIterator< Fragment_, Iterations_, AccessType_ > This_
This class.
Definition: fragment.h:230
-
CUTLASS_DEVICE AccessType const & operator[](int i) const
The accessor.
Definition: fragment.h:209
+
FragmentIterator< Fragment_, Iterations_, AccessType_ > This_
This class.
Definition: fragment.h:228
+
CUTLASS_HOST_DEVICE AccessType const & operator[](int i) const
The accessor.
Definition: fragment.h:263
+
CUTLASS_HOST_DEVICE Element & operator[](int i)
The accessor.
Definition: fragment.h:140
+
CUTLASS_HOST_DEVICE AccessType const & operator[](int i) const
The accessor.
Definition: fragment.h:207
uint16_t Type
Definition: fragment.h:85
Defines a 1D vector of elements held in the registers of each thread.
-
CUTLASS_DEVICE FragmentConstIterator(FragmentIterator< Fragment_, Iterations_, AccessType_ > const &rhs_)
Create from non-constant FragmentIterator.
Definition: fragment.h:254
-
static int const kElementsPerAccess
The number of elements per access.
Definition: fragment.h:183
-
ShapeStrides< FragmentShape >::Shape Strides
The linear strides for iterations.
Definition: fragment.h:187
+
uint64_t Type
Definition: fragment.h:77
+
CUTLASS_HOST_DEVICE bool valid(int d, int h, int w, int c) const
Is the iterator valid?
Definition: fragment.h:268
+
ShapeStrides< FragmentShape, kElementsPerAccess >::Shape Strides
The linear strides for iterations.
Definition: fragment.h:185
+
static CUTLASS_HOST_DEVICE int get(int d, int h, int w, int c)
Definition: shape.h:199
+
CUTLASS_HOST_DEVICE AccessType & operator[](int i)
The accessor.
Definition: fragment.h:212
+
CUTLASS_HOST_DEVICE AccessType & at(int d, int h, int w, int c=0)
The accessor.
Definition: fragment.h:201
+
static int const kElementsPerAccess
The number of elements per access.
Definition: fragment.h:181
Defines Shape implementing the Layout concept for representing a 4D hypercube of objects.
-
AccessType_ AccessType
The access type.
Definition: fragment.h:178
-
CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const
Is the iterator valid?
Definition: fragment.h:219
-
uint64_t Type
Definition: fragment.h:77
+
AccessType_ AccessType
The access type.
Definition: fragment.h:176
+
static int const kAlignment
Alignment.
Definition: fragment.h:112
Definition: cutlass_math.h:45
-
CUTLASS_DEVICE FragmentConstIterator(OtherFragment_ &fragment, int offset=0)
Ctor.
Definition: fragment.h:249
-
CUTLASS_DEVICE AccessType const & at(int d, int h, int w, int c=0) const
The accessor.
Definition: fragment.h:197
-
Element const * pointer
The pointer.
Definition: fragment.h:273
+
CUTLASS_HOST_DEVICE AccessType const & at(int d, int h, int w, int c=0) const
The accessor.
Definition: fragment.h:257
+
Element const * pointer
The pointer.
Definition: fragment.h:271
diff --git a/docs/fragment__multiply__add_8h.html b/docs/fragment__multiply__add_8h.html index 59a94dfdf..107cfee79 100644 --- a/docs/fragment__multiply__add_8h.html +++ b/docs/fragment__multiply__add_8h.html @@ -82,15 +82,15 @@ $(function() {

Defines multiply-add operations on fragments within a thread. More...

-

Classes

struct  cutlass::StorageType< kAlignment_ >
struct  cutlass::StorageType< alignment >
 
struct  cutlass::StorageType< 4 >
 
- + - +

Classes

struct  cutlass::gemm::FragmentMultiplyAdd< Scalar_ >
struct  cutlass::gemm::FragmentMultiplyAdd< ScalarAlphaBeta_, ScalarAccum_, fragMul2 >
 
struct  cutlass::gemm::FragmentMultiplyAdd< half >
struct  cutlass::gemm::FragmentMultiplyAdd< half, half, true >
 
diff --git a/docs/fragment__multiply__add_8h_source.html b/docs/fragment__multiply__add_8h_source.html index 9b453fd94..1d4c4f7f2 100644 --- a/docs/fragment__multiply__add_8h_source.html +++ b/docs/fragment__multiply__add_8h_source.html @@ -76,28 +76,26 @@ $(function() {
fragment_multiply_add.h
-Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include <cutlass/fragment.h>
31 
32 namespace cutlass {
33 namespace gemm {
34 
36 
37 template <typename Scalar_>
42  typedef Scalar_ ScalarA;
44  typedef Scalar_ ScalarB;
46  typedef Scalar_ ScalarC;
47 
49  CUTLASS_DEVICE FragmentMultiplyAdd() {}
50 
52  template <typename Fragment_>
53  CUTLASS_DEVICE void multiply(Scalar_ a, Fragment_ const& b, Fragment_& d) {
54  for (int j = 0; j < Fragment_::kElements; ++j) {
55  d[j] = a * b[j];
56  }
57  }
58 
60  template <typename Fragment_>
61  CUTLASS_DEVICE void multiply_add(Scalar_ a,
62  Fragment_ const& b,
63  Fragment_ const& c,
64  Fragment_& d) {
65  for (int j = 0; j < Fragment_::kElements; ++j) {
66  d[j] = a * b[j] + c[j];
67  }
68  }
69 };
70 
72 
73 #if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16)
74 template <>
75 struct FragmentMultiplyAdd<half> {
79  typedef half ScalarA;
81  typedef half ScalarB;
83  typedef half ScalarC;
84 
86  CUTLASS_DEVICE FragmentMultiplyAdd() {}
87 
89  template <typename Fragment_>
90  CUTLASS_DEVICE void multiply(half a, Fragment_ const& b, Fragment_& d) {
91 #if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
92  // The input.
93  __half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]);
94  // The output.
95  __half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);
96 
97  // Assemble a half2 from a.
98  __half2 const a_half2 = __half2half2(a);
99 
100  for (int i = 0; i < Fragment_::kElements / 2; ++i) {
101  d_half2[i] = __hmul2(a_half2, b_half2[i]);
102  }
103 #endif
104  }
105 
107  template <typename Fragment_>
108  CUTLASS_DEVICE void multiply_add(half a, Fragment_ const& b, Fragment_ const& c, Fragment_& d) {
109 #if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
110  // The inputs.
111  __half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]);
112  __half2 const* c_half2 = reinterpret_cast<__half2 const*>(&c[0]);
113  // The output.
114  __half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);
115 
116  // Assemble a half2 from a.
117  __half2 const a_half2 = __half2half2(a);
118 
119  for (int i = 0; i < Fragment_::kElements / 2; ++i) {
120  d_half2[i] = __hfma2(a_half2, b_half2[i], c_half2[i]);
121  }
122 #endif
123  }
124 };
125 
126 #endif
127 
129 
130 } // namespace gemm
131 } // namespace cutlass
Scalar_ ScalarB
The type for B.
Definition: fragment_multiply_add.h:44
+Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
28 #pragma once
29 
30 #include "cutlass/fragment.h"
31 
32 namespace cutlass {
33 namespace gemm {
34 
36 
37 template < typename ScalarAlphaBeta_,
38  typename ScalarAccum_,
39  bool fragMul2 = true /*number of element per fragment is multiple of 2*/
40 >
45  typedef ScalarAlphaBeta_ ScalarAlphaBeta;
47  typedef ScalarAccum_ ScalarAccum;
48 
50  CUTLASS_DEVICE FragmentMultiplyAdd() {}
51 
53  template <typename FragmentB_, typename FragmentCd_>
54  CUTLASS_DEVICE void multiply(ScalarAlphaBeta a, FragmentB_ const& b, FragmentCd_& d) {
55 #if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
56  int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
57  for (int j = 0; j < FragmentCd_::kElements; ++j) {
58  d[j] = b[j * kReduction + 0];
59  for (int k = 1; k < kReduction; ++k) {
60  d[j] += b[j * kReduction + k];
61  }
62  d[j] = a * ScalarAlphaBeta(d[j]);
63  }
64 #endif
65  }
66 
68  template <typename FragmentB_, typename FragmentCd_>
69  CUTLASS_DEVICE void multiply_add(ScalarAlphaBeta a,
70  FragmentB_ const& b,
71  FragmentCd_ const& c,
72  FragmentCd_& d) {
73 #if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
74  int const kReduction = FragmentB_::kElements / FragmentCd_::kElements;
75  for (int j = 0; j < FragmentCd_::kElements; ++j) {
76  d[j] = b[j * kReduction + 0];
77  for (int k = 1; k < kReduction; ++k) {
78  d[j] += b[j * kReduction + k];
79  }
80  d[j] = a * ScalarAlphaBeta(d[j]) + ScalarAlphaBeta(c[j]);
81  }
82 #endif
83  }
84 };
85 
87 
88 #if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16)
89 template <>
90 struct FragmentMultiplyAdd<half, half, true> {
94  typedef half ScalarAlphaBeta;
96  typedef half ScalarAccum;
97 
99  CUTLASS_DEVICE FragmentMultiplyAdd() {}
100 
102  template <typename FragmentB_, typename FragmentCd_>
103  CUTLASS_DEVICE void multiply(half a, FragmentB_ const& b, FragmentCd_& d) {
104 #if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
105  // The input.
106  __half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]);
107  // The output.
108  __half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);
109 
110  // Assemble a half2 from a.
111  __half2 const a_half2 = __half2half2(a);
112 
113  int const kReduction = (FragmentB_::kElements / FragmentCd_::kElements);
114 
115  for (int j = 0; j < FragmentCd_::kElements / 2; ++j) {
116  d_half2[j] = __hmul2(a_half2, b_half2[j * kReduction + 0]);
117 
118  for (int k = 1; k < kReduction; ++k) {
119  d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + k], d_half2[j]);
120  }
121  }
122 #endif
123  }
124 
125 
127  template <typename FragmentB_, typename FragmentCd_>
128  CUTLASS_DEVICE void multiply_add(half a,
129  FragmentB_ const& b,
130  FragmentCd_ const& c,
131  FragmentCd_& d) {
132 #if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
133  // The inputs.
134  __half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]);
135  __half2 const* c_half2 = reinterpret_cast<__half2 const*>(&c[0]);
136  // The output.
137  __half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);
138 
139  // Assemble a half2 from a.
140  __half2 const a_half2 = __half2half2(a);
141 
142  int const kReduction = (FragmentB_::kElements / FragmentCd_::kElements);
143  for (int j = 0; j < FragmentCd_::kElements / 2; ++j) {
144  d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + 0], c_half2[j]);
145 
146  for (int k = 1; k < kReduction; ++k) {
147  d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + k], d_half2[j]);
148  }
149  }
150 #endif
151  }
152 };
153 
154 #endif
155 
157 
158 } // namespace gemm
159 } // namespace cutlass
CUTLASS_DEVICE void multiply(ScalarAlphaBeta a, FragmentB_ const &b, FragmentCd_ &d)
Multiply : d = a*b.
Definition: fragment_multiply_add.h:54
+
Shape< 1, 1, 1, 1 > InstructionShape
The shape of the instruction.
Definition: fragment_multiply_add.h:92
Definition: convert.h:33
-
CUTLASS_DEVICE void multiply(Scalar_ a, Fragment_ const &b, Fragment_ &d)
Multiply : d = a*b.
Definition: fragment_multiply_add.h:53
-
half ScalarA
The type for A.
Definition: fragment_multiply_add.h:79
-
CUTLASS_DEVICE FragmentMultiplyAdd()
Ctor.
Definition: fragment_multiply_add.h:86
-
CUTLASS_DEVICE void multiply_add(Scalar_ a, Fragment_ const &b, Fragment_ const &c, Fragment_ &d)
Multiply : d = a*b + c.
Definition: fragment_multiply_add.h:61
-
half ScalarC
The type for C and D.
Definition: fragment_multiply_add.h:83
-
CUTLASS_DEVICE void multiply_add(half a, Fragment_ const &b, Fragment_ const &c, Fragment_ &d)
Multiply : d = a*b + c.
Definition: fragment_multiply_add.h:108
+
half ScalarAlphaBeta
The type for alpha and beta.
Definition: fragment_multiply_add.h:94
+
CUTLASS_DEVICE FragmentMultiplyAdd()
Ctor.
Definition: fragment_multiply_add.h:50
+
CUTLASS_DEVICE FragmentMultiplyAdd()
Ctor.
Definition: fragment_multiply_add.h:99
+
CUTLASS_DEVICE void multiply(half a, FragmentB_ const &b, FragmentCd_ &d)
Multiply : d = a*b.
Definition: fragment_multiply_add.h:103
+
ScalarAccum_ ScalarAccum
The type for accumlator.
Definition: fragment_multiply_add.h:47
A Shape implementing Layout Concept describing the dimensions of a cube.
Definition: shape.h:64
-
Shape< 1, 1, 1, 1 > InstructionShape
The shape of the instruction.
Definition: fragment_multiply_add.h:40
-
Scalar_ ScalarC
The type for C and D.
Definition: fragment_multiply_add.h:46
-
Scalar_ ScalarA
The type for A.
Definition: fragment_multiply_add.h:42
-
CUTLASS_DEVICE FragmentMultiplyAdd()
Ctor.
Definition: fragment_multiply_add.h:49
+
ScalarAlphaBeta_ ScalarAlphaBeta
The type for alpha and beta.
Definition: fragment_multiply_add.h:45
+
CUTLASS_DEVICE void multiply_add(half a, FragmentB_ const &b, FragmentCd_ const &c, FragmentCd_ &d)
Multiply : d = a*b + c.
Definition: fragment_multiply_add.h:128
+
Shape< 1, 1, 1, 1 > InstructionShape
The shape of the instruction.
Definition: fragment_multiply_add.h:43
Defines Fragment, a statically-sized array for storing parts of matrices within a thread&#39;s registers...
-
CUTLASS_DEVICE void multiply(half a, Fragment_ const &b, Fragment_ &d)
Multiply : d = a*b.
Definition: fragment_multiply_add.h:90
-
Shape< 1, 1, 1, 1 > InstructionShape
The shape of the instruction.
Definition: fragment_multiply_add.h:77
-
half ScalarB
The type for B.
Definition: fragment_multiply_add.h:81
-
Definition: fragment_multiply_add.h:38
+
half ScalarAccum
The type for accumlator.
Definition: fragment_multiply_add.h:96
+
CUTLASS_DEVICE void multiply_add(ScalarAlphaBeta a, FragmentB_ const &b, FragmentCd_ const &c, FragmentCd_ &d)
Multiply : d = a*b + c.
Definition: fragment_multiply_add.h:69
+
Definition: fragment_multiply_add.h:41
diff --git a/docs/functions.html b/docs/functions.html index e6b156fbc..bdde612a8 100644 --- a/docs/functions.html +++ b/docs/functions.html @@ -71,77 +71,101 @@ $(function() {
Here is a list of all class members with links to the classes they belong to:

- a -

diff --git a/docs/functions_0x7e.html b/docs/functions_0x7e.html index 41aa664c4..0cb0e3458 100644 --- a/docs/functions_0x7e.html +++ b/docs/functions_0x7e.html @@ -78,7 +78,7 @@ $(function() { diff --git a/docs/functions_b.html b/docs/functions_b.html index 79038aa18..35c2018ba 100644 --- a/docs/functions_b.html +++ b/docs/functions_b.html @@ -71,42 +71,82 @@ $(function() {
Here is a list of all class members with links to the classes they belong to:

- b -

diff --git a/docs/functions_enum.html b/docs/functions_enum.html index b710de0fe..df8ae39f4 100644 --- a/docs/functions_enum.html +++ b/docs/functions_enum.html @@ -70,18 +70,20 @@ $(function() { diff --git a/docs/functions_eval.html b/docs/functions_eval.html index 40c01ec85..b53129338 100644 --- a/docs/functions_eval.html +++ b/docs/functions_eval.html @@ -77,6 +77,13 @@ $(function() { +

- b -

+ +

- k -

@@ -139,6 +160,13 @@ $(function() { +

- o -

+ +

- v -

diff --git a/docs/functions_func_g.html b/docs/functions_func_g.html index b30237366..b6258df4a 100644 --- a/docs/functions_func_g.html +++ b/docs/functions_func_g.html @@ -74,47 +74,73 @@ $(function() {
  • Gemm() : cutlass::gemm::Gemm< GemmTraits_ >
  • +
  • GemmCoord() +: cutlass::gemm::GemmCoord +
  • +
  • GemmDesc() +: cutlass::gemm::GemmDesc< AType_, BType_, CType_, DType_, SType_, Index_ > +
  • GemmEpilogue() -: cutlass::gemm::GemmEpilogue< GemmEpilogueTraits_ > +: cutlass::gemm::GemmEpilogue< GemmEpilogueTraits_ >
  • GemmGlobalIteratorAb() -: cutlass::gemm::GemmGlobalIteratorAb< TileTraits_, Index_ > +: cutlass::gemm::GemmGlobalIteratorAb< TileTraits_, Index_ >
  • GemmGlobalIteratorCd() -: cutlass::gemm::GemmGlobalIteratorCd< TileTraits_, Index_ > +: cutlass::gemm::GemmGlobalIteratorCd< TileTraits_, Index_ >
  • get() -: cutlass::ComputeOffsetFromShape< Shape_ > -, cutlass::ComputeOffsetFromShape< Shape< 1, kSh_, kSw_, 1 > > -, cutlass::ComputeOffsetFromShape< Shape< 1, kSh_, kSw_, kSc_ > > -, cutlass::ComputeOffsetFromStrides< Strides_ > -, cutlass::ComputeOffsetFromStrides< Shape< 1, S_h_, S_w_, 1 > > -, cutlass::ComputeOffsetFromStrides< Shape< 1, S_h_, S_w_, S_c_ > > +: cutlass::ComputeOffsetFromShape< Shape_ > +, cutlass::ComputeOffsetFromStrides< Strides_ > , cutlass::ComputeThreadOffsetFromStrides< Threads_, Strides_ > , cutlass::ComputeThreadOffsetFromStrides< Shape< 1, T_h_, T_w_, 1 >, Shape< 1, S_h_, S_w_, 1 > > , cutlass::ComputeThreadOffsetFromStrides< Shape< 1, T_h_, T_w_, T_c_ >, Shape< 1, S_h_, S_w_, S_c_ > > +, cutlass::detail::ScalarOrPointer< Scalar_ > , cutlass::platform::unique_ptr< T, Deleter > , cutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::Iterator
  • +
  • get_batch_id() +: cutlass::gemm::ColumnMajorBlockSwizzle< groupCols, swDirection > +, cutlass::gemm::IdentityBlockSwizzle +, cutlass::gemm::RowMajorBlockSwizzle< groupRows, swDirection > +
  • get_deleter() -: cutlass::platform::unique_ptr< T, Deleter > +: cutlass::platform::unique_ptr< T, Deleter > +
  • +
  • get_grid_layout() +: cutlass::gemm::ColumnMajorBlockSwizzle< groupCols, swDirection > +, cutlass::gemm::IdentityBlockSwizzle +, cutlass::gemm::RowMajorBlockSwizzle< groupRows, swDirection > +
  • +
  • get_pointer_offset() +: cutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ > +
  • +
  • get_ptr() +: cutlass::detail::ScalarOrPointer< Scalar_ > +
  • +
  • get_scalar() +: cutlass::detail::ScalarOrPointer< Scalar_ > +
  • +
  • get_threadblock_offset() +: cutlass::gemm::ColumnMajorBlockSwizzle< groupCols, swDirection > +, cutlass::gemm::IdentityBlockSwizzle +, cutlass::gemm::RowMajorBlockSwizzle< groupRows, swDirection >
  • GlobalLoadStream() -: cutlass::gemm::GemmTraits< GemmConfig_, GlobalLoadStreamA_, GlobalLoadStreamB_, SharedLoadStreamA_, SharedLoadStreamB_, Epilogue_, BlockSwizzle_, Index_, ClearAccumulators_ >::GlobalLoadStream -, cutlass::gemm::GlobalLoadStream< LoadIterator_, StoreIterator_, Transformer_ > +: cutlass::gemm::GlobalLoadStream< Operand, LoadIterator_, StoreIterator_, Transformer_ >
  • -
  • GlobalLoadStreamBase() -: cutlass::gemm::GlobalLoadStreamBase< LoadIterator_, StoreIterator_, Transformer_ > +
  • GlobalLoadStreamPair() +: cutlass::gemm::GlobalLoadStreamPair< StreamA_, StreamB_, kResidueInProlog_ >
  • good() -: cutlass::TensorRef< Storage_, Rank_ > -, cutlass::TensorView< T > +: cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ > +, cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >
  • diff --git a/docs/functions_func_h.html b/docs/functions_func_h.html index 7eb85aac8..184c449e2 100644 --- a/docs/functions_func_h.html +++ b/docs/functions_func_h.html @@ -71,14 +71,23 @@ $(function() {  

    - h -

    diff --git a/docs/functions_func_i.html b/docs/functions_func_i.html index 16cfdc518..7c8680aba 100644 --- a/docs/functions_func_i.html +++ b/docs/functions_func_i.html @@ -72,90 +72,103 @@ $(function() {

    - i -

    diff --git a/docs/functions_func_k.html b/docs/functions_func_k.html new file mode 100644 index 000000000..ced1e9b63 --- /dev/null +++ b/docs/functions_func_k.html @@ -0,0 +1,98 @@ + + + + + + + +Cutlass: Class Members - Functions + + + + + + + + + + +
    +
    +

    @@ -103,7 +103,7 @@ Namespaces

    + + + + + +
    +
    Cutlass +
    +
    CUDA Templates for Linear Algebra Subroutines and Solvers
    +
    + + + + + + + + + + +
    +
    + + +
    + +
    + +
    +  + +

    - k -

    +
    + + + + diff --git a/docs/functions_func_l.html b/docs/functions_func_l.html index c76f9fc53..76d84a054 100644 --- a/docs/functions_func_l.html +++ b/docs/functions_func_l.html @@ -74,30 +74,56 @@ $(function() {
  • launch() : cutlass::gemm::Gemm< GemmTraits_ >
  • +
  • Launch() +: cutlass::gemm::Launch< Gemm, WithLaunchBounds > +, cutlass::gemm::Launch< Gemm, false > +
  • leading_dim() -: cutlass::TensorRef< Storage_, Rank_ > +: cutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ > +, cutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >
  • LinearScaling() -: cutlass::gemm::LinearScaling< Scalar_, FragmentMultiplyAdd_ > +: cutlass::gemm::LinearScaling< Scalar_, FragmentMultiplyAdd_ > +
  • +
  • LinearScalingDevicePtr() +: cutlass::gemm::LinearScalingDevicePtr< Scalar_, FragmentMultiplyAdd_ >
  • load() -: cutlass::FragmentLoad< IteratorFragment::kScalar, kAccessSize, Scalar_, Memory_, FragmentElement_, kStride > -, cutlass::FragmentLoad< IteratorFragment::kWmmaMatrix, kAccessSize, Scalar_, Memory_, FragmentElement_, kStride > -, cutlass::Load< Scalar_, Lanes_, Memory_, bool, size_t > -, cutlass::Load< double, 2, Memory_, true, 16 > -, cutlass::Load< Scalar_, Lanes_, Memory_, true, 16 > -, cutlass::Load< Scalar_, Lanes_, Memory_, true, 4 > -, cutlass::Load< Scalar_, Lanes_, Memory_, true, 8 > -, cutlass::TileLoadIterator< Traits_, Scalar_, Advance_, MemorySpace, Index_, FragmentElement_, IteratorFragment_, Skew_ > +: cutlass::Load< Scalar_, kAccessSize, Memory_, kFragmentElementType, FragmentElement_, kStride, size > +, cutlass::Load< double, 2, Memory_, FragmentElementType::kScalar, double, kStride, 16 > +, cutlass::Load< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, 1, 2 > +, cutlass::Load< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 16 > +, cutlass::Load< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 4 > +, cutlass::Load< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 8 > +, cutlass::Load< Scalar_, kAccessSize, Memory_, FragmentElementType::kWmmaMatrix, FragmentElement_, kStride, size > +, cutlass::Load< Vector< bin1_t, 32 >, kAccessSize, Memory_, FragmentElementType::kWmmaMatrix, FragmentElement_, kStride, size > +, cutlass::Load< Vector< int4_t, 8 >, kAccessSize, Memory_, FragmentElementType::kWmmaMatrix, FragmentElement_, kStride, size > +, cutlass::Load< Vector< uint4_t, 8 >, kAccessSize, Memory_, FragmentElementType::kWmmaMatrix, FragmentElement_, kStride, size > +, cutlass::TileLoadIterator< Traits_, Scalar_, Advance_, MemorySpace, Index_, FragmentElement_, FragmentElementType_, Skew_ > +, cutlass::TileStoreIterator< Traits_, Scalar_, Advance_, MemorySpace, Index_, FragmentElement_, FragmentElementType_, Skew_ > +, 
cutlass::ZipTileIterator< First_, Second_ > +
  • +
  • load_element() +: cutlass::gemm::GemmGlobalIteratorAb< TileTraits_, Index_ > +, cutlass::gemm::GemmGlobalIteratorCd< TileTraits_, Index_ > +, cutlass::gemm::IgemmGlobalIteratorAb< TileTraits_, Index_ > +, cutlass::gemm::WmmaGemmGlobalIteratorCd< TileTraits_, Index_ > +, cutlass::TileLoadIterator< Traits_, Scalar_, Advance_, MemorySpace, Index_, FragmentElement_, FragmentElementType_, Skew_ > +, cutlass::TileStoreIterator< Traits_, Scalar_, Advance_, MemorySpace, Index_, FragmentElement_, FragmentElementType_, Skew_ >
  • load_post_increment() -: cutlass::TileLoadIterator< Traits_, Scalar_, Advance_, MemorySpace, Index_, FragmentElement_, IteratorFragment_, Skew_ > +: cutlass::gemm::GemmGlobalIteratorAb< TileTraits_, Index_ > +, cutlass::gemm::GemmGlobalIteratorCd< TileTraits_, Index_ > +, cutlass::gemm::WmmaGemmGlobalIteratorCd< TileTraits_, Index_ > +, cutlass::TileLoadIterator< Traits_, Scalar_, Advance_, MemorySpace, Index_, FragmentElement_, FragmentElementType_, Skew_ > +, cutlass::TileStoreIterator< Traits_, Scalar_, Advance_, MemorySpace, Index_, FragmentElement_, FragmentElementType_, Skew_ > +, cutlass::ZipTileIterator< First_, Second_ >
  • diff --git a/docs/functions_func_m.html b/docs/functions_func_m.html index 2c68ec4f3..97f286560 100644 --- a/docs/functions_func_m.html +++ b/docs/functions_func_m.html @@ -71,23 +71,38 @@ $(function() {  

    - m -

    diff --git a/docs/functions_func_n.html b/docs/functions_func_n.html new file mode 100644 index 000000000..3b16224a8 --- /dev/null +++ b/docs/functions_func_n.html @@ -0,0 +1,89 @@ + + + + + + + +Cutlass: Class Members - Functions + + + + + + + + + + +
    +
    + + + + + + +
    +
    Cutlass +
    +
    CUDA Templates for Linear Algebra Subroutines and Solvers
    +
    +
    + + + + + + + +
    + +
    +
    + + +
    + +
    + +
    +  + +

    - n -

    +
    + + + + diff --git a/docs/functions_func_o.html b/docs/functions_func_o.html index fb7b39f73..a7c0b04d3 100644 --- a/docs/functions_func_o.html +++ b/docs/functions_func_o.html @@ -72,8 +72,12 @@ $(function() {

    - o -

    diff --git a/docs/functions_type_f.html b/docs/functions_type_f.html index a71defeb3..8ba4932b8 100644 --- a/docs/functions_type_f.html +++ b/docs/functions_type_f.html @@ -72,45 +72,59 @@ $(function() {

    - f -

    diff --git a/docs/functions_type_g.html b/docs/functions_type_g.html index 4ae366f31..05c838d38 100644 --- a/docs/functions_type_g.html +++ b/docs/functions_type_g.html @@ -73,19 +73,19 @@ $(function() {

    - g -

    diff --git a/docs/functions_type_k.html b/docs/functions_type_k.html new file mode 100644 index 000000000..b5e5add55 --- /dev/null +++ b/docs/functions_type_k.html @@ -0,0 +1,86 @@ + + + + + + + +Cutlass: Class Members - Typedefs + + + + + + + + + + +
    +
    + + + + + + +
    +
    Cutlass +
    +
    CUDA Templates for Linear Algebra Subroutines and Solvers
    +
    +
    + + + + + + + +
    + +
    +
    + + +
    + +
    + + + + + + diff --git a/docs/functions_type_l.html b/docs/functions_type_l.html index 2e7334f0d..abda026b1 100644 --- a/docs/functions_type_l.html +++ b/docs/functions_type_l.html @@ -72,13 +72,19 @@ $(function() {

    - l -

    diff --git a/docs/functions_type_m.html b/docs/functions_type_m.html index 043340a51..c76395492 100644 --- a/docs/functions_type_m.html +++ b/docs/functions_type_m.html @@ -71,26 +71,33 @@ $(function() {  

    - m -

    diff --git a/docs/functions_type_n.html b/docs/functions_type_n.html index bb5ad36c7..187630cfd 100644 --- a/docs/functions_type_n.html +++ b/docs/functions_type_n.html @@ -78,7 +78,7 @@ $(function() { diff --git a/docs/functions_type_o.html b/docs/functions_type_o.html index 42ed28139..d3d71d551 100644 --- a/docs/functions_type_o.html +++ b/docs/functions_type_o.html @@ -72,7 +72,7 @@ $(function() {

    - o -

    struct  cutlass::Fragment< Element_, kElements_, kAlignment_ >
     A template defining Fragment Concept. More...
     
    struct  cutlass::ZipFragment< First_, Second_ >
     A template defining Fragment Concept. More...
     

    Detailed Description

    Fragment Concept is a statically sized array for storing parts of tiles held by individual CUDA threads.

    @@ -94,7 +97,7 @@ Classes
    - +

    Classes

    struct  cutlass::TileLoadIterator< Traits_, Scalar_, Advance_, MemorySpace, Index_, FragmentElement_, IteratorFragment_, Skew_ >
    struct  cutlass::TileLoadIterator< Traits_, Scalar_, Advance_, MemorySpace, Index_, FragmentElement_, FragmentElementType_, Skew_ >
     An iterator implementing Tile Load Iterator Concept for loading a tile from memory. More...
     
    @@ -96,7 +96,7 @@ Classes
    - +

    Classes

    struct  cutlass::TileStoreIterator< Traits_, Scalar_, Advance_, MemorySpace, Index_, FragmentElement_, IteratorFragment_, Skew_ >
    struct  cutlass::TileStoreIterator< Traits_, Scalar_, Advance_, MemorySpace, Index_, FragmentElement_, FragmentElementType_, Skew_ >
     An iterator implementing Tile Store Iterator Concept for storing a tile to memory. More...
     
    @@ -96,7 +96,7 @@ Classes
    - +

    Classes

    struct  cutlass::TileTraits< Tile_, Delta_, Iterations_, ThreadOffset_ >
    struct  cutlass::TileTraits< Tile_, Delta_, Iterations_, ThreadOffset_, AccessSize >
     A template defining Tile Traits Concept. More...
     
    @@ -93,7 +93,7 @@ Classes
    @@ -107,7 +107,7 @@ Namespaces diff --git a/docs/hgemm__global__tile_8h_source.html b/docs/hgemm__global__tile_8h_source.html index bdd647d1a..8d7e02f65 100644 --- a/docs/hgemm__global__tile_8h_source.html +++ b/docs/hgemm__global__tile_8h_source.html @@ -76,34 +76,34 @@ $(function() {
    hgemm_global_tile.h
    -Go to the documentation of this file.
    1 /***************************************************************************************************
    2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
    3  *
    4  * Redistribution and use in source and binary forms, with or without modification, are permitted
    5  * provided that the following conditions are met:
    6  * * Redistributions of source code must retain the above copyright notice, this list of
    7  * conditions and the following disclaimer.
    8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
    9  * conditions and the following disclaimer in the documentation and/or other materials
    10  * provided with the distribution.
    11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
    12  * to endorse or promote products derived from this software without specific prior written
    13  * permission.
    14  *
    15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
    16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
    17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
    18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
    19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
    20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
    21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    23  *
    24  **************************************************************************************************/
    30 #pragma once
    31 
    32 #include <cutlass/coord.h>
    34 #include <cutlass/matrix_traits.h>
    35 #include <cutlass/reshape_tile.h>
    36 
    37 namespace cutlass {
    38 namespace gemm {
    39 
    41 
    42 template <GemmOperand::Kind kOperand_,
    43  MatrixLayout::Kind kLayout_,
    44  typename Scalar_,
    45  typename Tile_,
    46  typename Threads_,
    47  int kAccessSize_>
    49  // Which GEMM operand?
    50  kOperand_,
    51  // The layout.
    52  kLayout_,
    53  // The scalar.
    54  Scalar_,
    55  // The tile.
    56  Tile_,
    57  // The threads.
    58  Threads_,
    59  // The number of scalars per LDG/STG.
    60  kAccessSize_> {
    64  typedef typename Base::Threads Threads;
    70  typedef Shape<Base::Tile::kH / Base::Threads::kH / 2,
    71  2,
    72  Base::Tile::kW / Base::Threads::kW,
    73  Base::Tile::kC / Base::kAccessSize>
    76  struct ThreadOffset {
    78  Coord<4> operator()() const {
    79  int thread_offset_h = threadIdx.x / Threads::kW * ThreadsDelta::kH;
    80  int thread_offset_w = threadIdx.x % Threads::kW * ThreadsDelta::kW;
    81 
    82  return make_Coord(0, thread_offset_h, thread_offset_w, 0);
    83  }
    84  };
    85 };
    86 
    88 
    89 } // namespace gemm
    90 } // namespace cutlass
    Definition: convert.h:33
    +Go to the documentation of this file.
    1 /***************************************************************************************************
    2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
    3  *
    4  * Redistribution and use in source and binary forms, with or without modification, are permitted
    5  * provided that the following conditions are met:
    6  * * Redistributions of source code must retain the above copyright notice, this list of
    7  * conditions and the following disclaimer.
    8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
    9  * conditions and the following disclaimer in the documentation and/or other materials
    10  * provided with the distribution.
    11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
    12  * to endorse or promote products derived from this software without specific prior written
    13  * permission.
    14  *
    15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
    16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
    17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
    18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
    19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
    20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
    21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    23  *
    24  **************************************************************************************************/
    30 #pragma once
    31 
    32 #include "cutlass/coord.h"
    34 #include "cutlass/matrix_traits.h"
    35 #include "cutlass/reshape_tile.h"
    36 
    37 namespace cutlass {
    38 namespace gemm {
    39 
    41 
    42 template <GemmOperand::Kind kOperand_,
    43  MatrixLayout::Kind kLayout_,
    44  typename Scalar_,
    45  typename Tile_,
    46  typename Threads_,
    47  int kAccessSize_>
    49  // Which GEMM operand?
    50  kOperand_,
    51  // The layout.
    52  kLayout_,
    53  // The scalar.
    54  Scalar_,
    55  // The tile.
    56  Tile_,
    57  // The threads.
    58  Threads_,
    59  // The number of scalars per LDG/STG.
    60  kAccessSize_> {
    64  typedef typename Base::Threads Threads;
    70  typedef Shape<Base::VectorizedTile::kH / Base::Threads::kH / 2,
    71  2,
    72  Base::VectorizedTile::kW / Base::Threads::kW,
    73  Base::VectorizedTile::kC / Base::kAccessSize>
    76  struct ThreadOffset {
    78  Coord<4> operator()() const {
    79  int thread_offset_h = threadIdx.x / Threads::kW * ThreadsDelta::kH;
    80  int thread_offset_w = threadIdx.x % Threads::kW * ThreadsDelta::kW;
    81 
    82  return make_Coord(0, thread_offset_h, thread_offset_w, 0);
    83  }
    84  };
    85 };
    86 
    88 
    89 } // namespace gemm
    90 } // namespace cutlass
    Shape< Base::VectorizedTile::kH/Base::Threads::kH/2, 2, Base::VectorizedTile::kW/Base::Threads::kW, Base::VectorizedTile::kC/Base::kAccessSize > Iterations
    The number of iterations needed to load/store the tile.
    Definition: hgemm_global_tile.h:74
    +
    Definition: convert.h:33
    Defines iterators for efficiently loading and storing to global memory.
    Definition: gemm_global_tile.h:70
    A Coord is a coordinate of arbitrary rank into a tensor or matrix.
    -
    CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
    Helper to make a 2-element coordinate.
    Definition: coord.h:241
    -
    Shape< Base::Tile::kH/Base::Threads::kH/2, 2, Base::Tile::kW/Base::Threads::kW, Base::Tile::kC/Base::kAccessSize > Iterations
    The number of iterations needed to load/store the tile.
    Definition: hgemm_global_tile.h:74
    +
    CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
    Helper to make a 2-element coordinate.
    Definition: coord.h:318
    +
    Shape< 1, 2, Base::VectorizedTile::kC > ThreadsDelta
    The threads strides.
    Definition: hgemm_global_tile.h:66
    Base::Threads Threads
    The threads.
    Definition: hgemm_global_tile.h:64
    static int const kH
    The height of the cube.
    Definition: shape.h:68
    CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
    Definition: hgemm_global_tile.h:78
    Shape< Base::Threads::kH *2, 1, Base::Threads::kW, Base::kAccessSize > Delta
    The strides in each dimension between different loads/stores.
    Definition: hgemm_global_tile.h:68
    -
    Shape< 1, 2, Base::Tile::kC > ThreadsDelta
    The threads strides.
    Definition: hgemm_global_tile.h:66
    Defines a type for restructuring a tile.
    #define CUTLASS_HOST_DEVICE
    Definition: cutlass.h:46
    GemmGlobalTileTraits< kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_ > Base
    The base class.
    Definition: hgemm_global_tile.h:62
    Definition: hgemm_global_tile.h:48
    A Shape implementing Layout Concept describing the dimensions of a cube.
    Definition: shape.h:64
    +
    ReshapeThreads< VectorizedTile, Threads_ >::Threads Threads
    The threads shape.
    Definition: gemm_global_tile.h:88
    static int const kW
    The width of the cube.
    Definition: shape.h:70
    -
    Kind
    Definition: matrix_traits.h:36
    +
    Kind
    Enumeration defining fundamental contiguous layouts.
    Definition: matrix_traits.h:159
    static int const kAccessSize
    The number of scalars per LDG/STG.
    Definition: gemm_global_tile.h:80
    Computes the thread offset in (H, W) based on thread ID.
    Definition: hgemm_global_tile.h:76
    -
    Kind
    Definition: matrix_traits.h:43
    -
    ReshapeThreads< Tile, Threads_ >::Threads Threads
    The threads shape.
    Definition: gemm_global_tile.h:87
    +
    Kind
    Definition: matrix_traits.h:357
    Defines properties of matrices used to denote layout and operands to GEMM kernels.
    diff --git a/docs/hgemm__multiply__add_8h.html b/docs/hgemm__multiply__add_8h.html index 3c6c609e8..41ba8db9e 100644 --- a/docs/hgemm__multiply__add_8h.html +++ b/docs/hgemm__multiply__add_8h.html @@ -82,15 +82,15 @@ $(function() {

    Specialization implementing multiply-add operation on half-precision floating point fragments. More...

    -
    - - + +

    Classes

    struct  cutlass::gemm::ThreadMultiplyAdd< AccumulatorsPerThread_, ThreadsPerWarp_, half, half, half >
     Template performing matrix multiply-add operation within a thread. More...
    struct  cutlass::gemm::ThreadMultiplyAdd< ThreadGemmShape_, ThreadsPerWarp_, half, half, half >
     Template performing matrix multiply-add operation within a thread. More...
     
    diff --git a/docs/hgemm__multiply__add_8h_source.html b/docs/hgemm__multiply__add_8h_source.html index 73ef90409..40e849bfb 100644 --- a/docs/hgemm__multiply__add_8h_source.html +++ b/docs/hgemm__multiply__add_8h_source.html @@ -76,30 +76,31 @@ $(function() {
    hgemm_multiply_add.h
    -Go to the documentation of this file.
    1 /***************************************************************************************************
    2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
    3  *
    4  * Redistribution and use in source and binary forms, with or without modification, are permitted
    5  * provided that the following conditions are met:
    6  * * Redistributions of source code must retain the above copyright notice, this list of
    7  * conditions and the following disclaimer.
    8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
    9  * conditions and the following disclaimer in the documentation and/or other materials
    10  * provided with the distribution.
    11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
    12  * to endorse or promote products derived from this software without specific prior written
    13  * permission.
    14  *
    15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
    16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
    17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
    18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
    19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
    20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
    21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    23  *
    24  **************************************************************************************************/
    29 #pragma once
    30 
    31 #include <cutlass/fragment.h>
    32 
    34 
    35 namespace cutlass {
    36 namespace gemm {
    37 
    39 
    41 template <typename AccumulatorsPerThread_, typename ThreadsPerWarp_>
    42 struct ThreadMultiplyAdd<AccumulatorsPerThread_, ThreadsPerWarp_, half, half, half> {
    46  typedef AccumulatorsPerThread_ AccumulatorsPerThread;
    48  typedef ThreadsPerWarp_ ThreadsPerWarp;
    52  typedef half ScalarA;
    56  typedef half ScalarB;
    60  typedef half ScalarC;
    63 
    65  static_assert(AccumulatorsPerThread::kH % 2 == 0, "Invalid size");
    66  static_assert(AccumulatorsPerThread::kW % 2 == 0, "Invalid size");
    67 
    69  CUTLASS_DEVICE ThreadMultiplyAdd() {}
    70 
    72  CUTLASS_DEVICE void multiply_add(FragmentA const& a,
    73  FragmentB const& b,
    74  Accumulators const& c,
    75  Accumulators& d) {
    76 #if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
    77  // The inputs.
    78  __half2 const* a_half2 = reinterpret_cast<__half2 const*>(&a[0]);
    79  __half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]);
    80  __half2 const* c_half2 = reinterpret_cast<__half2 const*>(&c[0]);
    81 
    82  // The output.
    83  __half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);
    84 
    85  for (int j = 0; j < AccumulatorsPerThread::kH / 2; ++j) {
    86  for (int i = 0; i < AccumulatorsPerThread::kW / 2; ++i) {
    87  // The offsets in the output fragment.
    88  int const k0 = (2 * j + 0) * (AccumulatorsPerThread::kW / 2) + i;
    89  int const k1 = (2 * j + 1) * (AccumulatorsPerThread::kW / 2) + i;
    90 
    91  // Compute the product a[i] * b[j].H0_H0.
    92  d_half2[k0] = __hfma2(a_half2[i], __low2half2(b_half2[j]), c_half2[k0]);
    93  // Compute the product a[i] * b[j].H1_H1.
    94  d_half2[k1] = __hfma2(a_half2[i], __high2half2(b_half2[j]), c_half2[k1]);
    95  }
    96  }
    97 #endif
    98  }
    99 };
    100 
    102 
    103 } // namespace gemm
    104 } // namespace cutlass
    +Go to the documentation of this file.
    1 /***************************************************************************************************
    2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
    3  *
    4  * Redistribution and use in source and binary forms, with or without modification, are permitted
    5  * provided that the following conditions are met:
    6  * * Redistributions of source code must retain the above copyright notice, this list of
    7  * conditions and the following disclaimer.
    8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
    9  * conditions and the following disclaimer in the documentation and/or other materials
    10  * provided with the distribution.
    11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
    12  * to endorse or promote products derived from this software without specific prior written
    13  * permission.
    14  *
    15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
    16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
    17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
    18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
    19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
    20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
    21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    23  *
    24  **************************************************************************************************/
    29 #pragma once
    30 
    31 #include "cutlass/fragment.h"
    32 
    34 
    35 namespace cutlass {
    36 namespace gemm {
    37 
    39 
    41 template <typename ThreadGemmShape_, typename ThreadsPerWarp_>
    42 struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, half, half, half> {
    46  typedef ThreadGemmShape_ ThreadGemmShape;
    50  typedef ThreadsPerWarp_ ThreadsPerWarp;
    54  typedef half ScalarA;
    58  typedef half ScalarB;
    62  typedef half ScalarC;
    65 
    67  static_assert(AccumulatorsPerThread::kH % 2 == 0, "Invalid size");
    68  static_assert(AccumulatorsPerThread::kW % 2 == 0, "Invalid size");
    69 
    71  CUTLASS_DEVICE ThreadMultiplyAdd() {}
    72 
    74  CUTLASS_DEVICE void multiply_add(FragmentA const& a,
    75  FragmentB const& b,
    76  Accumulators const& c,
    77  Accumulators& d) {
    78 #if defined(__CUDACC__) && __CUDA_ARCH__ >= 530
    79  // The inputs.
    80  __half2 const* a_half2 = reinterpret_cast<__half2 const*>(&a[0]);
    81  __half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]);
    82  __half2 const* c_half2 = reinterpret_cast<__half2 const*>(&c[0]);
    83 
    84  // The output.
    85  __half2* d_half2 = reinterpret_cast<__half2*>(&d[0]);
    86 
    87  for (int j = 0; j < AccumulatorsPerThread::kH / 2; ++j) {
    88  for (int i = 0; i < AccumulatorsPerThread::kW / 2; ++i) {
    89  // The offsets in the output fragment.
    90  int const k0 = (2 * j + 0) * (AccumulatorsPerThread::kW / 2) + i;
    91  int const k1 = (2 * j + 1) * (AccumulatorsPerThread::kW / 2) + i;
    92 
    93  // Compute the product a[i] * b[j].low.
    94  d_half2[k0] = __hfma2(a_half2[i], __low2half2(b_half2[j]), c_half2[k0]);
    95  // Compute the product a[i] * b[j].high.
    96  d_half2[k1] = __hfma2(a_half2[i], __high2half2(b_half2[j]), c_half2[k1]);
    97  }
    98  }
    99 #endif
    100  }
    101 };
    102 
    104 
    105 } // namespace gemm
    106 } // namespace cutlass
    CUTLASS_DEVICE ThreadMultiplyAdd()
    Make sure there&#39;s an even number of elements in both dimensions.
    Definition: hgemm_multiply_add.h:71
    +
    half ScalarC
    The type for C and D.
    Definition: hgemm_multiply_add.h:62
    Definition: convert.h:33
    -
    Fragment< half, AccumulatorsPerThread::kH *AccumulatorsPerThread::kW > Accumulators
    The accumulators.
    Definition: hgemm_multiply_add.h:62
    -
    ShapeMul< AccumulatorsPerThread, ThreadsPerWarp >::Shape AccumulatorsPerWarp
    The number of accumulators per warp.
    Definition: hgemm_multiply_add.h:50
    -
    half ScalarC
    The type for C and D.
    Definition: hgemm_multiply_add.h:60
    -
    CUTLASS_DEVICE ThreadMultiplyAdd()
    Make sure there&#39;s an even number of elements in both dimensions.
    Definition: hgemm_multiply_add.h:69
    +
    Fragment< ScalarB, AccumulatorsPerThread::kH > FragmentB
    The fragment for B.
    Definition: hgemm_multiply_add.h:60
    +
    ThreadGemmShape_ ThreadGemmShape
    The number of accumulators per thread.
    Definition: hgemm_multiply_add.h:46
    +
    Shape< A_::kD *B_::kD, A_::kH *B_::kH, A_::kW *B_::kW, A_::kC *B_::kC > Shape
    Definition: shape.h:119
    A template defining Fragment Concept.
    Definition: fragment.h:99
    Template implementing matrix multiply-add operations on fragments.
    -
    Shape< 1, 1, 2, 1 > InstructionShape
    The shape of the instruction.
    Definition: hgemm_multiply_add.h:44
    - -
    ThreadsPerWarp_ ThreadsPerWarp
    The number of threads per warp.
    Definition: hgemm_multiply_add.h:48
    -
    AccumulatorsPerThread_ AccumulatorsPerThread
    The number of accumulators per thread.
    Definition: hgemm_multiply_add.h:46
    -
    #define static_assert(__e, __m)
    Definition: platform.h:145
    -
    CUTLASS_DEVICE void multiply_add(FragmentA const &a, FragmentB const &b, Accumulators const &c, Accumulators &d)
    Multiply : d = a*b + c.
    Definition: hgemm_multiply_add.h:72
    +
    Shape< 1, 1, 2, 1 > InstructionShape
    The shape of the instruction.
    Definition: hgemm_multiply_add.h:44
    +
    ShapeMul< ThreadGemmShape, ThreadsPerWarp >::Shape AccumulatorsPerWarp
    The number of accumulators per warp.
    Definition: hgemm_multiply_add.h:52
    +
    Fragment< ScalarA, AccumulatorsPerThread::kW > FragmentA
    The fragment for A.
    Definition: hgemm_multiply_add.h:56
    +
    Fragment< half, AccumulatorsPerThread::kH *AccumulatorsPerThread::kW > Accumulators
    The accumulators.
    Definition: hgemm_multiply_add.h:64
    +
    CUTLASS_DEVICE void multiply_add(FragmentA const &a, FragmentB const &b, Accumulators const &c, Accumulators &d)
    Multiply : d = a*b + c.
    Definition: hgemm_multiply_add.h:74
    + +
    #define static_assert(__e, __m)
    Definition: platform.h:153
    A Shape implementing Layout Concept describing the dimensions of a cube.
    Definition: shape.h:64
    -
    Template performing matrix multiply-add operation within a thread.
    Definition: thread_multiply_add.h:43
    +
    Template performing matrix multiply-add operation within a thread.
    Definition: thread_multiply_add.h:44
    +
    ThreadGemmShape AccumulatorsPerThread
    Aliased for compatibility. Will be removed for CUTLASS v2.0.
    Definition: hgemm_multiply_add.h:48
    +
    ThreadsPerWarp_ ThreadsPerWarp
    The number of threads per warp.
    Definition: hgemm_multiply_add.h:50
    Defines Fragment, a statically-sized array for storing parts of matrices within a thread&#39;s registers...
    -
    Fragment< ScalarA, AccumulatorsPerThread::kW > FragmentA
    The fragment for A.
    Definition: hgemm_multiply_add.h:54
    -
    Fragment< ScalarB, AccumulatorsPerThread::kH > FragmentB
    The fragment for B.
    Definition: hgemm_multiply_add.h:58
    diff --git a/docs/hgemm__swizzle_8h.html b/docs/hgemm__swizzle_8h.html index aef7ac75e..93938799b 100644 --- a/docs/hgemm__swizzle_8h.html +++ b/docs/hgemm__swizzle_8h.html @@ -83,7 +83,7 @@ $(function() {

    Transposes a tile of 16b elements. Used by HGEMM to construct a K-strided layout in shared memory for multiplicands. More...

    #include <cuda_fp16.h>
    -#include <cutlass/fragment.h>
    +#include "cutlass/fragment.h"

    Go to the source code of this file.

    @@ -103,7 +103,7 @@ Namespaces

    @@ -102,7 +102,7 @@ Namespaces diff --git a/docs/hgemm__swizzle_8h_source.html b/docs/hgemm__swizzle_8h_source.html index bb76b510c..d882c10f3 100644 --- a/docs/hgemm__swizzle_8h_source.html +++ b/docs/hgemm__swizzle_8h_source.html @@ -76,14 +76,14 @@ $(function() {
    hgemm_swizzle.h
    -Go to the documentation of this file.
    1 /***************************************************************************************************
    2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
    3  *
    4  * Redistribution and use in source and binary forms, with or without modification, are permitted
    5  * provided that the following conditions are met:
    6  * * Redistributions of source code must retain the above copyright notice, this list of
    7  * conditions and the following disclaimer.
    8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
    9  * conditions and the following disclaimer in the documentation and/or other materials
    10  * provided with the distribution.
    11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
    12  * to endorse or promote products derived from this software without specific prior written
    13  * permission.
    14  *
    15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
    16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
    17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
    18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
    19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
    20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
    21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    23  *
    24  **************************************************************************************************/
    29 #pragma once
    30 
    31 #include <cuda_fp16.h>
    32 #include <cutlass/fragment.h>
    33 
    34 namespace cutlass {
    35 namespace gemm {
    36 
    38 
    39 template <typename GlobalIterator_>
    40 struct HgemmSwizzle {
    42  typedef GlobalIterator_ GlobalIterator;
    44  typedef typename GlobalIterator::Fragment Fragment;
    46  typedef typename GlobalIterator::FragmentShape FragmentShape;
    47 
    52 
    55 
    57  static_assert(FragmentShape::kH == 2 && ShapeCount<FragmentShape>::kWc == 2, "Not multiple of 2");
    58 
    60  CUTLASS_DEVICE HgemmSwizzle() {}
    61 
    63  CUTLASS_DEVICE void transform(Fragment const& src, Fragment& dst) {
    64  // Expose src/dst as int arrays.
    65  int const* src_int = reinterpret_cast<int const*>(&src[0]);
    66  int* dst_int = reinterpret_cast<int*>(&dst[0]);
    67 
    68  // Transpose the data.
    69  for (int d = 0; d < FragmentShape::kD; ++d) {
    70  // The indices to read two consecutive "rows".
    71  int const i0 = 2 * d + 0;
    72  int const i1 = 2 * d + 1;
    73 
    74  int a0 = src_int[i0];
    75  int a1 = src_int[i1];
    76 
    77  int b0, b1;
    78  asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b0) : "r"(a0), "r"(a1));
    79  asm volatile("prmt.b32 %0, %1, %2, 0x7632;" : "=r"(b1) : "r"(a0), "r"(a1));
    80 
    81  // The indices to store with "strides".
    82  int const j0 = 0 * (ShapeCount<FragmentShape>::kDhw / 2) + d;
    83  int const j1 = 1 * (ShapeCount<FragmentShape>::kDhw / 2) + d;
    84 
    85  dst_int[j0] = b0;
    86  dst_int[j1] = b1;
    87  }
    88  }
    89 };
    90 
    92 
    93 } // namespace gemm
    94 } // namespace cutlass
    GlobalIterator_ GlobalIterator
    The global iterator.
    Definition: hgemm_swizzle.h:42
    +Go to the documentation of this file.
    1 /***************************************************************************************************
    2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
    3  *
    4  * Redistribution and use in source and binary forms, with or without modification, are permitted
    5  * provided that the following conditions are met:
    6  * * Redistributions of source code must retain the above copyright notice, this list of
    7  * conditions and the following disclaimer.
    8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
    9  * conditions and the following disclaimer in the documentation and/or other materials
    10  * provided with the distribution.
    11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
    12  * to endorse or promote products derived from this software without specific prior written
    13  * permission.
    14  *
    15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
    16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
    17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
    18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
    19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
    20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
    21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    23  *
    24  **************************************************************************************************/
    29 #pragma once
    30 
    31 #include <cuda_fp16.h>
    32 #include "cutlass/fragment.h"
    33 
    34 namespace cutlass {
    35 namespace gemm {
    36 
    38 
    39 template <typename GlobalIterator_>
    40 struct HgemmSwizzle {
    42  typedef GlobalIterator_ GlobalIterator;
    44  typedef typename GlobalIterator::Fragment Fragment;
    46  typedef typename GlobalIterator::FragmentShape FragmentShape;
    47 
    52 
    55 
    57  static_assert(FragmentShape::kH == 2 && ShapeCount<FragmentShape>::kWc == 2, "Not multiple of 2");
    58 
    60  CUTLASS_DEVICE HgemmSwizzle() {}
    61 
    63  CUTLASS_DEVICE void transform(Fragment const& src, Fragment& dst) {
    64  // Expose src/dst as int arrays.
    65  int const* src_int = reinterpret_cast<int const*>(&src[0]);
    66  int* dst_int = reinterpret_cast<int*>(&dst[0]);
    67 
    68  // Transpose the data.
    69  for (int d = 0; d < FragmentShape::kD; ++d) {
    70  // The indices to read two consecutive "rows".
    71  int const i0 = 2 * d + 0;
    72  int const i1 = 2 * d + 1;
    73 
    74  int a0 = src_int[i0];
    75  int a1 = src_int[i1];
    76 
    77  int b0, b1;
    78  asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b0) : "r"(a0), "r"(a1));
    79  asm volatile("prmt.b32 %0, %1, %2, 0x7632;" : "=r"(b1) : "r"(a0), "r"(a1));
    80 
    81  // The indices to store with "strides".
    82  int const j0 = 0 * (ShapeCount<FragmentShape>::kDhw / 2) + d;
    83  int const j1 = 1 * (ShapeCount<FragmentShape>::kDhw / 2) + d;
    84 
    85  dst_int[j0] = b0;
    86  dst_int[j1] = b1;
    87  }
    88  }
    89 };
    90 
    92 
    93 } // namespace gemm
    94 } // namespace cutlass
    GlobalIterator_ GlobalIterator
    The global iterator.
    Definition: hgemm_swizzle.h:42
    Definition: convert.h:33
    -
    std::is_same (false specialization)
    Definition: platform.h:412
    +
    std::is_same (false specialization)
    Definition: platform.h:420
    CUTLASS_DEVICE HgemmSwizzle()
    The src/dst must be half fragments.
    Definition: hgemm_swizzle.h:60
    CUTLASS_DEVICE void transform(Fragment const &src, Fragment &dst)
    Transform a fragment.
    Definition: hgemm_swizzle.h:63
    Fragment InputFragment
    The input fragment.
    Definition: hgemm_swizzle.h:49
    Fragment OutputFragment
    The output fragment.
    Definition: hgemm_swizzle.h:51
    -
    #define static_assert(__e, __m)
    Definition: platform.h:145
    +
    #define static_assert(__e, __m)
    Definition: platform.h:153
    GlobalIterator::Fragment Fragment
    The source fragment.
    Definition: hgemm_swizzle.h:44
    Defines Fragment, a statically-sized array for storing parts of matrices within a thread&#39;s registers...
    GlobalIterator::FragmentShape FragmentShape
    The shape of the source fragment.
    Definition: hgemm_swizzle.h:46
    @@ -92,7 +92,7 @@ $(function() {
    diff --git a/docs/hgemm__traits_8h.html b/docs/hgemm__traits_8h.html index 283ceb750..bb8e72d99 100644 --- a/docs/hgemm__traits_8h.html +++ b/docs/hgemm__traits_8h.html @@ -82,23 +82,23 @@ $(function() {

    Defies structural properties of half-precision GEMM computation. More...

    -
    - + @@ -120,9 +120,9 @@ Classes - + - +

    Classes

    struct  cutlass::gemm::HgemmConfig< OutputTile_, AccumulatorsPerThread_, kScalarsPerLdgA_, kScalarsPerLdgB_ >
    struct  cutlass::gemm::HgemmConfig< OutputTile_, ThreadGemmShape_, kScalarsPerLdgA_, kScalarsPerLdgB_ >
     
    struct  cutlass::gemm::HgemmTransformerA< kLayout_, Iterator_ >
     
     
    struct  cutlass::gemm::HgemmTileTraitsHelperB< MatrixLayout::kColumnMajor, GemmConfig_ >
     
    struct  cutlass::gemm::HgemmTraitsHelper< kLayoutA_, kLayoutB_, OutputTile_, EpilogueFunctor_, AccumulatorsPerThread_, kScalarsPerLdgA_, kScalarsPerLdgB_, Index_ >
    struct  cutlass::gemm::HgemmTraitsHelper< kLayoutA_, kLayoutB_, OutputTile_, EpilogueFunctor_, ThreadGemmShape_, kScalarsPerLdgA_, kScalarsPerLdgB_, Index_ >
     
    struct  cutlass::gemm::HgemmTraits< kLayoutA_, kLayoutB_, OutputTile_, EpilogueFunctor_, AccumulatorsPerThread_, kScalarsPerLdgA_, kScalarsPerLdgB_, Index_, Helper_ >
    struct  cutlass::gemm::HgemmTraits< kLayoutA_, kLayoutB_, OutputTile_, EpilogueFunctor_, ThreadGemmShape_, kScalarsPerLdgA_, kScalarsPerLdgB_, Index_, Helper_ >
     
    diff --git a/docs/hgemm__traits_8h_source.html b/docs/hgemm__traits_8h_source.html index 0d12493ec..db1554c86 100644 --- a/docs/hgemm__traits_8h_source.html +++ b/docs/hgemm__traits_8h_source.html @@ -76,89 +76,87 @@ $(function() {
    hgemm_traits.h
    -Go to the documentation of this file.
    1 /***************************************************************************************************
    2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
    3  *
    4  * Redistribution and use in source and binary forms, with or without modification, are permitted
    5  * provided that the following conditions are met:
    6  * * Redistributions of source code must retain the above copyright notice, this list of
    7  * conditions and the following disclaimer.
    8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
    9  * conditions and the following disclaimer in the documentation and/or other materials
    10  * provided with the distribution.
    11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
    12  * to endorse or promote products derived from this software without specific prior written
    13  * permission.
    14  *
    15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
    16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
    17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
    18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
    19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
    20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
    21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    23  *
    24  **************************************************************************************************/
    28 #pragma once
    29 
    30 #include <cutlass/convert.h>
    31 #include <cutlass/reshape_tile.h>
    32 
    33 #include <cutlass/gemm/gemm.h>
    42 
    43 namespace cutlass {
    44 namespace gemm {
    45 
    47 
    48 template <
    50  typename OutputTile_,
    52  typename AccumulatorsPerThread_,
    54  int kScalarsPerLdgA_ = 2,
    56  int kScalarsPerLdgB_ = 2>
    58  : public GemmConfig<
    60  half,
    62  half,
    64  half,
    66  half,
    68  OutputTile_,
    70  ThreadMultiplyAdd<AccumulatorsPerThread_, Shape<1, 4, 8>, half, half, half>,
    72  kScalarsPerLdgA_,
    74  kScalarsPerLdgA_,
    76  8,
    78  kScalarsPerLdgB_,
    80  kScalarsPerLdgB_,
    82  8,
    84  2,
    86  8,
    88  2,
    90  2> {};
    91 
    93 
    94 template <enum MatrixLayout::Kind kLayout_, typename Iterator_>
    96 
    97 template <typename Iterator_>
    98 struct HgemmTransformerA<MatrixLayout::kColumnMajor, Iterator_> {
    100 };
    101 
    102 template <typename Iterator_>
    103 struct HgemmTransformerA<MatrixLayout::kRowMajor, Iterator_> {
    105 };
    106 
    108 
    109 template <enum MatrixLayout::Kind kLayout_, typename Iterator_>
    111 
    112 template <typename Iterator_>
    113 struct HgemmTransformerB<MatrixLayout::kRowMajor, Iterator_> {
    115 };
    116 
    117 template <typename Iterator_>
    118 struct HgemmTransformerB<MatrixLayout::kColumnMajor, Iterator_> {
    120 };
    121 
    123 
    124 template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
    125 struct HgemmTileTraitsHelperA : public GemmTileTraitsHelperA<kLayout_, GemmConfig_> {};
    126 
    128 
    129 template <typename GemmConfig_>
    130 struct HgemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_>
    131  : public GemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_> {
    134 
    138  // The layout.
    140  // The pointer.
    141  half const,
    142  // The tile has size MxK in GEMM's terminology.
    144  // The threads are distributed as (threads / K ) x K (the traits may reorganize).
    145  Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
    146  // The number of scalars per LDG (LDG.32 or LDG.128, etc)
    147  GemmConfig_::kScalarsPerLdgA>
    149 
    152  // The pointer.
    153  half,
    154  // The tile has size KxM in GEMM's terminology.
    155  Shape<GemmConfig_::kStages,
    156  GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
    157  GemmConfig_::OutputTile::kW * GemmConfig_::InstructionShape::kD>,
    158  // The threads are distributed as warps x 32(the traits may reorganize).
    159  typename GlobalTileTraits::Threads,
    160  // The number of scalars per STS (STS.32 or STS.128, etc).
    161  2,
    162  // The skew to avoid bank conflicts added in the tile W dimension.
    163  128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2>
    165 
    168  // The pointer.
    169  half const,
    170  // The output tile size.
    171  typename GemmConfig_::OutputTile,
    172  // The number of warps.
    173  typename GemmConfig_::Warps,
    174  // The number of threads per warp.
    175  typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
    176  // The shape of the FMA instruction.
    177  typename GemmConfig_::InstructionShape,
    178  // The number of stages.
    179  GemmConfig_::kStages,
    180  // The number of scalars per LDS.
    181  8,
    182  // The skew.
    183  SharedStoreTileTraits::kSkew>
    185 };
    186 
    188 
    189 template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
    190 struct HgemmTileTraitsHelperB : public GemmTileTraitsHelperB<kLayout_, GemmConfig_> {};
    191 
    193 
    194 template <typename GemmConfig_>
    195 struct HgemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_>
    196  : public GemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_> {
    199 
    203  // The layout.
    205  // The pointer.
    206  half const,
    207  // The tile has size KxN in GEMM's terminology.
    209  // The threads are distributed as (threads / K) x K (the traits may reorganize).
    210  Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
    211  // The number of scalars per LDG (LDG.32 or LDG.128, etc)
    212  GemmConfig_::kScalarsPerLdgB>
    214 
    217  // The pointer.
    218  half,
    219  // The tile has size KxN in GEMM's terminology.
    220  Shape<GemmConfig_::kStages,
    221  GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
    222  GemmConfig_::OutputTile::kH * GemmConfig_::InstructionShape::kD>,
    223  // The threads are distributed as (threads / K) x K (the traits may reorganize).
    224  typename GlobalTileTraits::Threads,
    225  // The number of scalars per STS (STS.32 or STS.128, etc).
    226  2,
    227  // The skew to avoid bank conflicts added in the tile W dimension.
    228  128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2>
    230 
    233  // The pointer.
    234  half const,
    235  // The output tile size.
    236  typename GemmConfig_::OutputTile,
    237  // The number of warps.
    238  typename GemmConfig_::Warps,
    239  // The number of threads per warp.
    240  typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
    241  // The shape of the FMA instruction.
    242  typename GemmConfig_::InstructionShape,
    243  // The number of stages.
    244  GemmConfig_::kStages,
    245  // The number of scalars per LDS.
    246  8,
    247  // The skew.
    248  SharedStoreTileTraits::kSkew>
    250 };
    251 
    253 
    254 template <
    256  MatrixLayout::Kind kLayoutA_,
    258  MatrixLayout::Kind kLayoutB_,
    260  typename OutputTile_,
    262  typename EpilogueFunctor_,
    264  typename AccumulatorsPerThread_ = Shape<32, 8, 8>,
    266  int kScalarsPerLdgA_ = 2,
    268  int kScalarsPerLdgB_ = 2,
    270  typename Index_ = int>
    279 
    284  typedef typename HgemmTransformerA<GemmTileTraitsHelperA::kLayout,
    287  typedef TileStoreIterator<typename GemmTileTraitsHelperA::SharedStoreTileTraits,
    288  typename GemmTileTraitsHelperA::SharedStoreTileTraits::Scalar,
    295 
    299  // The default transformer for B.
    300  typedef typename HgemmTransformerB<GemmTileTraitsHelperB::kLayout,
    303  typedef TileStoreIterator<typename GemmTileTraitsHelperB::SharedStoreTileTraits,
    304  typename GemmTileTraitsHelperB::SharedStoreTileTraits::Scalar,
    311 
    313  typedef TileLoadIterator<typename GemmTileTraitsHelperA::SharedLoadTileTraits,
    314  typename GemmTileTraitsHelperA::SharedLoadTileTraits::Scalar,
    321  typedef TileLoadIterator<typename GemmTileTraitsHelperB::SharedLoadTileTraits,
    322  typename GemmTileTraitsHelperB::SharedLoadTileTraits::Scalar,
    328 
    333 
    338 };
    339 
    341 
    342 template <
    344  MatrixLayout::Kind kLayoutA_,
    346  MatrixLayout::Kind kLayoutB_,
    348  typename OutputTile_ = Shape<8, 128, 128>,
    350  typename EpilogueFunctor_ = LinearScaling<half>,
    352  typename AccumulatorsPerThread_ = Shape<8, 8, 16>,
    354  int kScalarsPerLdgA_ = 2,
    356  int kScalarsPerLdgB_ = 2,
    358  typename Index_ = int,
    360  typename Helper_ = HgemmTraitsHelper<kLayoutA_,
    361  kLayoutB_,
    362  OutputTile_,
    363  EpilogueFunctor_,
    364  AccumulatorsPerThread_,
    365  kScalarsPerLdgA_,
    366  kScalarsPerLdgB_,
    367  Index_> >
    368 struct HgemmTraits : public GemmTraits<
    369  // The config.
    370  typename Helper_::GemmConfig,
    371  // The stream to load A from global memory to shared memory.
    372  typename Helper_::GlobalLoadStreamA,
    373  // The stream to load B from global memory to shared memory.
    374  typename Helper_::GlobalLoadStreamB,
    375  // The stream to load A from shared memory.
    376  typename Helper_::SharedLoadStreamA,
    377  // The stream to load B from shared memory.
    378  typename Helper_::SharedLoadStreamB,
    379  // The epilogue.
    380  typename Helper_::Epilogue,
    381  // The block swizzle to reorganize the grid.
    382  IdentityBlockSwizzle,
    383  // The index.
    384  Index_,
    385  // The tool used to clear accumulators.
    386  typename Helper_::ClearAccumulators> {};
    387 
    389 
    390 } // namespace gemm
    391 } // namespace cutlass
    GemmGlobalIteratorAb< typename GemmTileTraitsHelperA::GlobalTileTraits, Index_ > GlobalLoadIteratorA
    The iterator to load A from global memory.
    Definition: hgemm_traits.h:282
    -
    Definition: load_store.h:42
    -
    HgemmSwizzle< Iterator_ > Transformer
    Definition: hgemm_traits.h:119
    +Go to the documentation of this file.
    1 /***************************************************************************************************
    2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
    3  *
    4  * Redistribution and use in source and binary forms, with or without modification, are permitted
    5  * provided that the following conditions are met:
    6  * * Redistributions of source code must retain the above copyright notice, this list of
    7  * conditions and the following disclaimer.
    8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
    9  * conditions and the following disclaimer in the documentation and/or other materials
    10  * provided with the distribution.
    11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
    12  * to endorse or promote products derived from this software without specific prior written
    13  * permission.
    14  *
    15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
    16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
    17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
    18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
    19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
    20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
    21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    23  *
    24  **************************************************************************************************/
    28 #pragma once
    29 
    30 #include "cutlass/convert.h"
    31 #include "cutlass/reshape_tile.h"
    32 
    33 #include "cutlass/gemm/gemm.h"
    42 
    43 namespace cutlass {
    44 namespace gemm {
    45 
    47 
    48 template <
    50  typename OutputTile_,
    52  typename ThreadGemmShape_,
    54  int kScalarsPerLdgA_ = 2,
    56  int kScalarsPerLdgB_ = 2>
    57 struct HgemmConfig : public GemmConfig<
    59  half,
    61  half,
    63  half,
    65  half,
    67  OutputTile_,
    69  ThreadMultiplyAdd<ThreadGemmShape_, Shape<1, 4, 8>, half, half, half>,
    71  kScalarsPerLdgA_,
    73  kScalarsPerLdgA_,
    75  8,
    77  kScalarsPerLdgB_,
    79  kScalarsPerLdgB_,
    81  8,
    83  2,
    85  8,
    87  2,
    89  2,
    91  false,
    93  true,
    95  false
    96  > {};
    97 
    99 
    100 template <enum MatrixLayout::Kind kLayout_, typename Iterator_>
    102 
    103 template <typename Iterator_>
    104 struct HgemmTransformerA<MatrixLayout::kColumnMajor, Iterator_> {
    106 };
    107 
    108 template <typename Iterator_>
    109 struct HgemmTransformerA<MatrixLayout::kRowMajor, Iterator_> {
    111 };
    112 
    114 
    115 template <enum MatrixLayout::Kind kLayout_, typename Iterator_>
    117 
    118 template <typename Iterator_>
    119 struct HgemmTransformerB<MatrixLayout::kRowMajor, Iterator_> {
    121 };
    122 
    123 template <typename Iterator_>
    124 struct HgemmTransformerB<MatrixLayout::kColumnMajor, Iterator_> {
    126 };
    127 
    129 
    130 template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
    131 struct HgemmTileTraitsHelperA : public GemmTileTraitsHelperA<kLayout_, GemmConfig_> {};
    132 
    134 
    135 template <typename GemmConfig_>
    136 struct HgemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_>
    137  : public GemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_> {
    140 
    144  // The layout.
    146  // The pointer.
    147  half const,
    148  // The tile has size MxK in GEMM's terminology.
    150  // The threads are distributed as (threads / K ) x K (the traits may reorganize).
    151  Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
    152  // The number of scalars per LDG (LDG.32 or LDG.128, etc)
    153  GemmConfig_::kScalarsPerLdgA>
    155 
    156  static int const kSkewA = 128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2;
    157 
    160  // The pointer.
    161  half,
    162  // The tile has size KxM in GEMM's terminology.
    163  Shape<GemmConfig_::kStages,
    164  GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
    165  GemmConfig_::OutputTile::kW * GemmConfig_::InstructionShape::kD>,
    166  // The threads are distributed as warps x 32(the traits may reorganize).
    167  typename GlobalTileTraits::Threads,
    168  // The number of scalars per STS (STS.32 or STS.128, etc).
    169  2,
    170  // The skew to avoid bank conflicts added in the tile W dimension.
    171  kSkewA<GemmConfig_::kScalarsPerLdsA ? GemmConfig_::kScalarsPerLdsA : kSkewA>
    172  SharedStoreTileTraits;
    173 
    176  // The pointer.
    177  half const,
    178  // The output tile size.
    179  typename GemmConfig_::OutputTile,
    180  // The number of warps.
    181  typename GemmConfig_::Warps,
    182  // The number of threads per warp.
    183  typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
    184  // The shape of the FMA instruction.
    185  typename GemmConfig_::InstructionShape,
    186  // The number of stages.
    187  GemmConfig_::kStages,
    188  // The number of scalars per LDS.
    189  8,
    190  // The skew.
    191  SharedStoreTileTraits::kSkew>
    192  SharedLoadTileTraits;
    193 };
    194 
    196 
    197 template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
    198 struct HgemmTileTraitsHelperB : public GemmTileTraitsHelperB<kLayout_, GemmConfig_> {};
    199 
    201 
    202 template <typename GemmConfig_>
    203 struct HgemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_>
    204  : public GemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_> {
    207 
    211  // The layout.
    213  // The pointer.
    214  half const,
    215  // The tile has size KxN in GEMM's terminology.
    217  // The threads are distributed as (threads / K) x K (the traits may reorganize).
    218  Shape<1, GemmConfig_::kThreads / GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD>,
    219  // The number of scalars per LDG (LDG.32 or LDG.128, etc)
    220  GemmConfig_::kScalarsPerLdgB>
    222 
    223  static int const kSkewB = 128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2;
    224 
    227  // The pointer.
    228  half,
    229  // The tile has size KxN in GEMM's terminology.
    230  Shape<GemmConfig_::kStages,
    231  GemmConfig_::OutputTile::kD / GemmConfig_::InstructionShape::kD,
    232  GemmConfig_::OutputTile::kH * GemmConfig_::InstructionShape::kD>,
    233  // The threads are distributed as (threads / K) x K (the traits may reorganize).
    234  typename GlobalTileTraits::Threads,
    235  // The number of scalars per STS (STS.32 or STS.128, etc).
    236  2,
    237  // The skew to avoid bank conflicts added in the tile W dimension.
    238  kSkewB<GemmConfig_::kScalarsPerLdsB ? GemmConfig_::kScalarsPerLdsB : kSkewB>
    239  SharedStoreTileTraits;
    240 
    243  // The pointer.
    244  half const,
    245  // The output tile size.
    246  typename GemmConfig_::OutputTile,
    247  // The number of warps.
    248  typename GemmConfig_::Warps,
    249  // The number of threads per warp.
    250  typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
    251  // The shape of the FMA instruction.
    252  typename GemmConfig_::InstructionShape,
    253  // The number of stages.
    254  GemmConfig_::kStages,
    255  // The number of scalars per LDS.
    256  8,
    257  // The skew.
    258  SharedStoreTileTraits::kSkew>
    259  SharedLoadTileTraits;
    260 };
    261 
    263 
    264 template <
    266  MatrixLayout::Kind kLayoutA_,
    268  MatrixLayout::Kind kLayoutB_,
    270  typename OutputTile_,
    272  typename EpilogueFunctor_,
    274  typename ThreadGemmShape_,
    276  int kScalarsPerLdgA_ = 2,
    278  int kScalarsPerLdgB_ = 2,
    280  typename Index_ = int>
    288 
    293  typedef typename HgemmTransformerA<GemmTileTraitsHelperA::kLayout,
    296  typedef TileStoreIterator<typename GemmTileTraitsHelperA::SharedStoreTileTraits,
    297  typename GemmTileTraitsHelperA::SharedStoreTileTraits::Scalar,
    307 
    311  // The default transformer for B.
    312  typedef typename HgemmTransformerB<GemmTileTraitsHelperB::kLayout,
    315  typedef TileStoreIterator<typename GemmTileTraitsHelperB::SharedStoreTileTraits,
    316  typename GemmTileTraitsHelperB::SharedStoreTileTraits::Scalar,
    326 
    328  typedef TileLoadIterator<typename GemmTileTraitsHelperA::SharedLoadTileTraits,
    329  typename GemmTileTraitsHelperA::SharedLoadTileTraits::Scalar,
    336  typedef TileLoadIterator<typename GemmTileTraitsHelperB::SharedLoadTileTraits,
    337  typename GemmTileTraitsHelperB::SharedLoadTileTraits::Scalar,
    343 
    348 
    353 };
    354 
    356 
    357 template <
    359  MatrixLayout::Kind kLayoutA_,
    361  MatrixLayout::Kind kLayoutB_,
    363  typename OutputTile_ = Shape<8, 128, 128>,
    365  typename EpilogueFunctor_ = LinearScaling<half>,
    367  typename ThreadGemmShape_ = Shape<8, 8, 16>,
    369  int kScalarsPerLdgA_ = 2,
    371  int kScalarsPerLdgB_ = 2,
    373  typename Index_ = int,
    375  typename Helper_ = HgemmTraitsHelper<kLayoutA_,
    376  kLayoutB_,
    377  OutputTile_,
    378  EpilogueFunctor_,
    379  ThreadGemmShape_,
    380  kScalarsPerLdgA_,
    381  kScalarsPerLdgB_,
    382  Index_> >
    383 struct HgemmTraits : public GemmTraits<
    384  // The config.
    385  typename Helper_::GemmConfig,
    386  // The stream to load A from global memory to shared memory.
    387  typename Helper_::GlobalLoadStreamA,
    388  // The stream to load B from global memory to shared memory.
    389  typename Helper_::GlobalLoadStreamB,
    390  // The stream to load A from shared memory.
    391  typename Helper_::SharedLoadStreamA,
    392  // The stream to load B from shared memory.
    393  typename Helper_::SharedLoadStreamB,
    394  // The epilogue.
    395  typename Helper_::Epilogue,
    396  // The block swizzle to reorganize the grid.
    397  IdentityBlockSwizzle,
    398  // The index.
    399  Index_,
    400  // The tool used to clear accumulators.
    401  typename Helper_::ClearAccumulators> {};
    402 
    404 
    405 } // namespace gemm
    406 } // namespace cutlass
    SharedLoadStream< SharedLoadIteratorB > SharedLoadStreamB
    The stream to load B from shared memory.
    Definition: hgemm_traits.h:342
    +
    GemmGlobalIteratorAb< typename GemmTileTraitsHelperB::GlobalTileTraits, Index_ > GlobalLoadIteratorB
    The iterator to load B from global memory.
    Definition: hgemm_traits.h:310
    +
    Definition: load_store.h:41
    +
    HgemmSwizzle< Iterator_ > Transformer
    Definition: hgemm_traits.h:125
    Definition: convert.h:33
    -
    Definition: gemm_shared_tile.h:129
    +
    HgemmConfig< OutputTile_, ThreadGemmShape_, kScalarsPerLdgA_, kScalarsPerLdgB_ > GemmConfig
    The HGEMM config.
    Definition: hgemm_traits.h:283
    +
    Definition: gemm_shared_tile.h:128
    -
    Definition: gemm_epilogue.h:53
    +
    Definition: gemm_epilogue.h:42
    Defines iterators for efficiently loading and storing to global memory.
    -
    GemmGlobalIteratorAb< typename GemmTileTraitsHelperB::GlobalTileTraits, Index_ > GlobalLoadIteratorB
    The iterator to load B from global memory.
    Definition: hgemm_traits.h:298
    -
    ClearAccumulators< typename MultiplyAdd::ScalarC > ClearAccumulators
    The object to clear accumulators.
    Definition: hgemm_traits.h:332
    +
    SimplifiedGemmEpilogueTraits< GemmConfig, EpilogueFunctor_, Index_ > GemmEpilogueTraits
    The traits class for the epilogue.
    Definition: hgemm_traits.h:350
    Defines structural properties of complete GEMM computation.
    -
    TileStoreIterator< typename GemmTileTraitsHelperA::SharedStoreTileTraits, typename GemmTileTraitsHelperA::SharedStoreTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedStoreIteratorA
    The iterator to store A to shared memory.
    Definition: hgemm_traits.h:291
    -
    GlobalLoadStream< GlobalLoadIteratorA, SharedStoreIteratorA, GlobalTransformerA > GlobalLoadStreamA
    The stream to load A from global memory to shared memory.
    Definition: hgemm_traits.h:294
    -
    HgemmCrosswiseGlobalTileTraits< GemmOperand::kB, MatrixLayout::kColumnMajor, half const, Shape< 1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD >, Shape< 1, GemmConfig_::kThreads/GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD >, GemmConfig_::kScalarsPerLdgB > GlobalTileTraits
    The traits class to build the iterator to load data from global memory for B^N.
    Definition: hgemm_traits.h:213
    -
    Definition: hgemm_traits.h:95
    -
    GemmTileTraitsHelperB< MatrixLayout::kColumnMajor, GemmConfig_ > Base
    The base config.
    Definition: hgemm_traits.h:198
    -
    SharedLoadStream< SharedLoadIteratorA > SharedLoadStreamA
    The stream to load A from shared memory.
    Definition: hgemm_traits.h:319
    -
    Convert< typename Iterator_::Fragment, typename Iterator_::Fragment > Transformer
    Definition: hgemm_traits.h:99
    -
    Definition: hgemm_traits.h:368
    -
    HgemmSwizzle< Iterator_ > Transformer
    Definition: hgemm_traits.h:104
    -
    Definition: tile_iterator.h:62
    -
    Definition: gemm_shared_tile.h:198
    -
    TileLoadIterator< typename GemmTileTraitsHelperB::SharedLoadTileTraits, typename GemmTileTraitsHelperB::SharedLoadTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedLoadIteratorB
    The iterator to load B from shared memory.
    Definition: hgemm_traits.h:325
    -
    Definition: gemm_global_tile.h:159
    -
    GemmEpilogue< GemmEpilogueTraits > Epilogue
    The epilogue.
    Definition: hgemm_traits.h:337
    -
    HgemmTransformerA< GemmTileTraitsHelperA::kLayout, GlobalLoadIteratorA >::Transformer GlobalTransformerA
    The default transformer for A.
    Definition: hgemm_traits.h:285
    +
    HgemmCrosswiseGlobalTileTraits< GemmOperand::kB, MatrixLayout::kColumnMajor, half const, Shape< 1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD >, Shape< 1, GemmConfig_::kThreads/GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD >, GemmConfig_::kScalarsPerLdgB > GlobalTileTraits
    The traits class to build the iterator to load data from global memory for B^N.
    Definition: hgemm_traits.h:221
    +
    Definition: hgemm_traits.h:101
    +
    GemmTileTraitsHelperB< MatrixLayout::kColumnMajor, GemmConfig_ > Base
    The base config.
    Definition: hgemm_traits.h:206
    +
    GemmEpilogue< GemmEpilogueTraits > Epilogue
    The epilogue.
    Definition: hgemm_traits.h:352
    +
    Convert< typename Iterator_::Fragment, typename Iterator_::Fragment > Transformer
    Definition: hgemm_traits.h:105
    +
    GlobalLoadStream< GemmOperand::kA, GlobalLoadIteratorA, SharedStoreIteratorA, GlobalTransformerA > GlobalLoadStreamA
    The stream to load A from global memory to shared memory.
    Definition: hgemm_traits.h:306
    +
    Definition: hgemm_traits.h:383
    +
    HgemmSwizzle< Iterator_ > Transformer
    Definition: hgemm_traits.h:110
    +
    TileLoadIterator< typename GemmTileTraitsHelperB::SharedLoadTileTraits, typename GemmTileTraitsHelperB::SharedLoadTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedLoadIteratorB
    The iterator to load B from shared memory.
    Definition: hgemm_traits.h:340
    +
    HgemmTransformerA< GemmTileTraitsHelperA::kLayout, GlobalLoadIteratorA >::Transformer GlobalTransformerA
    The default transformer for A.
    Definition: hgemm_traits.h:294
    +
    Definition: tile_iterator.h:65
    +
    Definition: gemm_shared_tile.h:200
    +
    Definition: gemm_global_tile.h:163
    Implements the epilogue phase of the GEMM kernel that efficiently updates global memory with the comp...
    -
    Definition: gemm_global_stream.h:161
    -
    Definition: gemm_traits.h:273
    -
    Definition: hgemm_traits.h:125
    -
    Describes layouts of matrices.
    Definition: matrix_traits.h:35
    -
    SharedLoadStream< SharedLoadIteratorB > SharedLoadStreamB
    The stream to load B from shared memory.
    Definition: hgemm_traits.h:327
    -
    Definition: hgemm_traits.h:110
    -
    GemmTileTraitsHelperA< MatrixLayout::kRowMajor, GemmConfig_ > Base
    The base config.
    Definition: hgemm_traits.h:133
    -
    TileLoadIterator< typename GemmTileTraitsHelperA::SharedLoadTileTraits, typename GemmTileTraitsHelperA::SharedLoadTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedLoadIteratorA
    The iterator to load A from shared memory.
    Definition: hgemm_traits.h:317
    -
    An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
    Definition: tile_iterator.h:302
    -
    SimplifiedGemmEpilogueTraits< GemmConfig, EpilogueFunctor_, Index_ > GemmEpilogueTraits
    The traits class for the epilogue.
    Definition: hgemm_traits.h:335
    +
    Definition: gemm_global_stream.h:52
    +
    Definition: gemm_traits.h:191
    +
    Definition: hgemm_traits.h:131
    +
    HgemmTileTraitsHelperA< kLayoutA_, GemmConfig > GemmTileTraitsHelperA
    The GEMM config for A.
    Definition: hgemm_traits.h:285
    +
    Defines data layouts of various matrix formats usable by TensorRef and other classes.
    Definition: matrix_traits.h:156
    +
    Definition: hgemm_traits.h:116
    +
    GemmTileTraitsHelperA< MatrixLayout::kRowMajor, GemmConfig_ > Base
    The base config.
    Definition: hgemm_traits.h:139
    +
    An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
    Definition: tile_iterator.h:399
    Defines iterators for efficiently loading and storing tiles to and from shared memory.
    -
    Definition: matrix_traits.h:36
    - -
    Definition: gemm_shared_stream.h:44
    +
    Definition: matrix_traits.h:159
    + +
    Definition: gemm_shared_stream.h:45
    +
    HgemmTransformerB< GemmTileTraitsHelperB::kLayout, GlobalLoadIteratorB >::Transformer GlobalTransformerB
    Definition: hgemm_traits.h:313
    Defines a type for restructuring a tile.
    +
    ClearAccumulators< typename MultiplyAdd::ScalarC > ClearAccumulators
    The object to clear accumulators.
    Definition: hgemm_traits.h:347
    Specialization implementing multiply-add operation on half-precision floating point fragments...
    -
    Definition: gemm_traits.h:79
    +
    Definition: gemm_config.h:76
    +
    TileLoadIterator< typename GemmTileTraitsHelperA::SharedLoadTileTraits, typename GemmTileTraitsHelperA::SharedLoadTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedLoadIteratorA
    The iterator to load A from shared memory.
    Definition: hgemm_traits.h:332
    Transposes a tile of 16b elements. Used by HGEMM to construct a K-strided layout in shared memory for...
    -
    Definition: gemm_traits.h:137
    -
    GemmSharedLoadTileBTraits< half const, typename GemmConfig_::OutputTile, typename GemmConfig_::Warps, typename GemmConfig_::MultiplyAdd::ThreadsPerWarp, typename GemmConfig_::InstructionShape, GemmConfig_::kStages, 8, SharedStoreTileTraits::kSkew > SharedLoadTileTraits
    The traits class to build the iterator to load from shared memory for B^N.
    Definition: hgemm_traits.h:249
    -
    Definition: matrix_traits.h:43
    -
    HgemmConfig< OutputTile_, AccumulatorsPerThread_, kScalarsPerLdgA_, kScalarsPerLdgB_ > GemmConfig
    The HGEMM config.
    Definition: hgemm_traits.h:274
    -
    Definition: hgemm_traits.h:190
    -
    GlobalLoadStream< GlobalLoadIteratorB, SharedStoreIteratorB, GlobalTransformerB > GlobalLoadStreamB
    The stream to load B from global memory to shared memory.
    Definition: hgemm_traits.h:310
    -
    GemmConfig::MultiplyAdd MultiplyAdd
    The functor to do the multiply-add in the main loop.
    Definition: hgemm_traits.h:330
    -
    HgemmTileTraitsHelperB< kLayoutB_, GemmConfig > GemmTileTraitsHelperB
    The GEMM config for B.
    Definition: hgemm_traits.h:278
    -
    Definition: gemm_traits.h:428
    +
    Definition: gemm_traits.h:52
    +
    Definition: matrix_traits.h:357
    +
    Definition: hgemm_traits.h:198
    +
    GemmConfig::MultiplyAdd MultiplyAdd
    The functor to do the multiply-add in the main loop.
    Definition: hgemm_traits.h:345
    +
    Definition: gemm_traits.h:349
    +
    HgemmTileTraitsHelperB< kLayoutB_, GemmConfig > GemmTileTraitsHelperB
    The GEMM config for B.
    Definition: hgemm_traits.h:287
    Definition: hgemm_global_tile.h:48
    A Shape implementing Layout Concept describing the dimensions of a cube.
    Definition: shape.h:64
    -
    Definition: gemm_epilogue_traits.h:300
    -
    GemmSharedLoadTileATraits< half const, typename GemmConfig_::OutputTile, typename GemmConfig_::Warps, typename GemmConfig_::MultiplyAdd::ThreadsPerWarp, typename GemmConfig_::InstructionShape, GemmConfig_::kStages, 8, SharedStoreTileTraits::kSkew > SharedLoadTileTraits
    The traits class to build the iterator to load from shared memory for A^T.
    Definition: hgemm_traits.h:184
    -
    HgemmTileTraitsHelperA< kLayoutA_, GemmConfig > GemmTileTraitsHelperA
    The GEMM config for A.
    Definition: hgemm_traits.h:276
    -
    Template performing matrix multiply-add operation within a thread.
    Definition: thread_multiply_add.h:43
    -
    Definition: matrix_traits.h:36
    -
    Kind
    Definition: matrix_traits.h:36
    -
    HgemmTransformerB< GemmTileTraitsHelperB::kLayout, GlobalLoadIteratorB >::Transformer GlobalTransformerB
    Definition: hgemm_traits.h:301
    - -
    Definition: hgemm_traits.h:271
    -
    HgemmCrosswiseGlobalTileTraits< GemmOperand::kA, MatrixLayout::kRowMajor, half const, Shape< 1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD >, Shape< 1, GemmConfig_::kThreads/GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD >, GemmConfig_::kScalarsPerLdgA > GlobalTileTraits
    The traits class to build the iterator to load data from global memory for A^T.
    Definition: hgemm_traits.h:148
    +
    Definition: gemm_epilogue_traits.h:323
    +
    ReshapeThreads< VectorizedTile, Threads_ >::Threads Threads
    The threads shape.
    Definition: gemm_global_tile.h:88
    +
    Template performing matrix multiply-add operation within a thread.
    Definition: thread_multiply_add.h:44
    +
    Definition: matrix_traits.h:159
    +
    Kind
    Enumeration defining fundamental contiguous layouts.
    Definition: matrix_traits.h:159
    +
    GemmGlobalIteratorAb< typename GemmTileTraitsHelperA::GlobalTileTraits, Index_ > GlobalLoadIteratorA
    The iterator to load A from global memory.
    Definition: hgemm_traits.h:291
    + +
    Definition: hgemm_traits.h:281
    +
    HgemmCrosswiseGlobalTileTraits< GemmOperand::kA, MatrixLayout::kRowMajor, half const, Shape< 1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD >, Shape< 1, GemmConfig_::kThreads/GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kD >, GemmConfig_::kScalarsPerLdgA > GlobalTileTraits
    The traits class to build the iterator to load data from global memory for A^T.
    Definition: hgemm_traits.h:154
    Tile traits used to construct global tile iterator for HGEMM. This is intended to partition the threa...
    -
    Functor to compute linear combination of fragments.
    Definition: linear_scaling.h:40
    +
    Functor to compute linear combination of fragments.
    Definition: linear_scaling.h:51
    Definition: convert.h:38
    -
    Definition: matrix_traits.h:43
    +
    Definition: matrix_traits.h:357
    Implements a software-pipelined efficient GEMM.
    -
    ReshapeThreads< Tile, Threads_ >::Threads Threads
    The threads shape.
    Definition: gemm_global_tile.h:87
    +
    GlobalLoadStream< GemmOperand::kB, GlobalLoadIteratorB, SharedStoreIteratorB, GlobalTransformerB > GlobalLoadStreamB
    The stream to load B from global memory to shared memory.
    Definition: hgemm_traits.h:325
    +
    SharedLoadStream< SharedLoadIteratorA > SharedLoadStreamA
    The stream to load A from shared memory.
    Definition: hgemm_traits.h:334
    Defines structural properties of the GEMM epilogue.
    +
    TileStoreIterator< typename GemmTileTraitsHelperB::SharedStoreTileTraits, typename GemmTileTraitsHelperB::SharedStoreTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedStoreIteratorB
    The iterator to store B to shared memory.
    Definition: hgemm_traits.h:319
    +
    TileStoreIterator< typename GemmTileTraitsHelperA::SharedStoreTileTraits, typename GemmTileTraitsHelperA::SharedStoreTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedStoreIteratorA
    The iterator to store A to shared memory.
    Definition: hgemm_traits.h:300
    Definition: hgemm_swizzle.h:40
    Defines conversion operations among Fragments of different base type.
    -
    Convert< typename Iterator_::Fragment, typename Iterator_::Fragment > Transformer
    Definition: hgemm_traits.h:114
    +
    Convert< typename Iterator_::Fragment, typename Iterator_::Fragment > Transformer
    Definition: hgemm_traits.h:120
    Definition: hgemm_traits.h:57
    -
    An iterator implementing Tile Store Iterator Concept for storing a tile to memory.
    Definition: tile_iterator.h:620
    -
    TileStoreIterator< typename GemmTileTraitsHelperB::SharedStoreTileTraits, typename GemmTileTraitsHelperB::SharedStoreTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedStoreIteratorB
    The iterator to store B to shared memory.
    Definition: hgemm_traits.h:307
    +
    An iterator implementing Tile Store Iterator Concept for storing a tile to memory.
    Definition: tile_iterator.h:836
    diff --git a/docs/hierarchy.html b/docs/hierarchy.html index 25ba6bdab..865698e5a 100644 --- a/docs/hierarchy.html +++ b/docs/hierarchy.html @@ -73,7 +73,7 @@ $(function() {
    This inheritance list is sorted roughly, but not completely, alphabetically:
    -
    [detail level 123]

    @@ -135,7 +135,7 @@ Namespaces

    +
    [detail level 1234]
    @@ -94,316 +94,389 @@ $(function() { - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
     Ccutlass::platform::aligned_chunk< Align >
     Ccutlass::platform::aligned_storage< Len, Align >Std::aligned_storage
     Ccutlass::AlignedStruct< kAlignment_ >
     Ccutlass::platform::alignment_of< ulong4 >
     Ccutlass::platform::alignment_of< ulonglong2 >
     Ccutlass::platform::alignment_of< ulonglong4 >
     Ccutlass::gemm::ClearAccumulators< Scalar_, kLanes_ >
     Ccutlass::ComputeOffsetFromShape< Shape_ >Compute the offset for the given coordinates in a cube
     Ccutlass::ComputeOffsetFromShape< Shape< 1, kSh_, kSw_, 1 > >Compute the offset for the given coordinates in a cube with one channel and a depth of 1
     Ccutlass::ComputeOffsetFromShape< Shape< 1, kSh_, kSw_, kSc_ > >Compute the offset for the given coordinates in a cube with a depth of 1
     Ccutlass::ComputeOffsetFromStrides< Strides_ >Compute the offset for the given coordinates in a cube
     Ccutlass::ComputeOffsetFromStrides< Shape< 1, S_h_, S_w_, 1 > >Compute the offset for the given coordinates in a cube with one channel and a depth of 1
     Ccutlass::ComputeOffsetFromStrides< Shape< 1, S_h_, S_w_, S_c_ > >Compute the offset for the given coordinates in a cube with a depth of 1
     Ccutlass::ComputeThreadOffsetFromStrides< Threads_, Strides_ >Decompose threadId.x into coordinate of a cube whose dimensions are specified by Threads_. Afterwards compute the offset of those coordinates using Strides_
     Ccutlass::ComputeThreadOffsetFromStrides< Shape< 1, T_h_, T_w_, 1 >, Shape< 1, S_h_, S_w_, 1 > >Specialization for D=1 and C=1
     Ccutlass::ComputeThreadOffsetFromStrides< Shape< 1, T_h_, T_w_, T_c_ >, Shape< 1, S_h_, S_w_, S_c_ > >Specialization for D=1
     Ccutlass::platform::conditional< B, T, F >Std::conditional (true specialization)
     Ccutlass::platform::conditional< false, T, F >Std::conditional (false specialization)
     Ccutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::ConstIteratorA const iterator implementing Predicate Iterator Concept enabling sequential read-only access to prediactes
     Ccutlass::ConstPredicateTileAdapter< PredicateVector_, Iterations_ >Adapter to enable random access to predicates via logical coordinate within a tile
     Ccutlass::Convert< InputFragment_, OutputFragment_ >
     Ccutlass::Convert< Fragment< InputScalar_, kScalars_ >, Fragment< OutputScalar_, kScalars_ > >
     Ccutlass::Coord< N_ >Statically-sized array specifying Coords within a tensor
     Ccutlass::Coord< 4 >
     Ccutlass::Coord< Rank >
     Ccutlass::Copy< Fragment_ >
     Ccutlass::platform::default_delete< T >Default deleter
     Ccutlass::platform::default_delete< T[]>Partial specialization for deleting array types
     Ccutlass::divide_assert< Dividend, Divisor >
     Ccutlass::platform::is_base_of_helper< BaseT, DerivedT >::dummy< B, D >
     Ccutlass::platform::enable_if< C, T >Std::enable_if (true specialization)
     Ccutlass::platform::enable_if< false, T >Std::enable_if (false specialization)
     Ccutlass::Extent< T >Returns the extent of a scalar or vector
     Ccutlass::Extent< Vector< T, Lanes > >Returns the number of lanes of a vector if need be
     Ccutlass::Extent< Vector< T, Lanes > const >Returns the number of lanes of a vector if need be
     Ccutlass::FragmentConstIterator< Fragment_, Iterations_, AccessType_ >
     Ccutlass::FragmentIterator< Fragment_, Iterations_, AccessType_ >A template defining Fragment Iterator Concept
     Ccutlass::FragmentLoad< kIteratorFragment, kAccessSize, Scalar_, Memory_, FragmentElement_, kStride >
     Ccutlass::FragmentLoad< IteratorFragment::kScalar, kAccessSize, Scalar_, Memory_, FragmentElement_, kStride >
     Ccutlass::FragmentLoad< IteratorFragment::kWmmaMatrix, kAccessSize, Scalar_, Memory_, FragmentElement_, kStride >
     Ccutlass::gemm::FragmentMultiplyAdd< Scalar_ >
     Ccutlass::gemm::FragmentMultiplyAdd< half >
     Ccutlass::FragmentStore< kIteratorFragment, kAccessSize, Scalar_, Memory_, FragmentElement_, kStride >
     Ccutlass::FragmentStore< IteratorFragment::kScalar, kAccessSize, Scalar_, Memory_, FragmentElement_, kStride >
     Ccutlass::FragmentStore< IteratorFragment::kWmmaMatrix, kAccessSize, Scalar_, Memory_, FragmentElement_, kStride >
     Ccutlass::gemm::Gemm< GemmTraits_ >
     Ccutlass::gemm::GemmConfig< ScalarA_, ScalarB_, ScalarC_, ScalarD_, OutputTile_, MultiplyAdd_, kScalarsPerLdgA_, kScalarsPerStsA_, kScalarsPerLdsA_, kScalarsPerLdgB_, kScalarsPerStsB_, kScalarsPerLdsB_, kScalarsPerLdgCAndStgD_, kScalarsPerStsD_, kScalarsPerLdsD_, kStages_ >
     Ccutlass::gemm::GemmConfig< double, double, double, double, OutputTile_, ThreadMultiplyAdd< AccumulatorsPerThread_, Shape< 1, 4, 8 >, double, double, double >, kScalarsPerLdgA_, kScalarsPerLdgA_, 2, kScalarsPerLdgB_, kScalarsPerLdgB_, 2, 1, 2, 1, 2 >
     Ccutlass::gemm::GemmConfig< float, float, float, float, OutputTile_, ThreadMultiplyAdd< AccumulatorsPerThread_, Shape< 1, 4, 8 >, float, float, float >, kScalarsPerLdgA_, kScalarsPerLdgA_, 4, kScalarsPerLdgB_, kScalarsPerLdgB_, 4, 1, 4, 1, 2 >
     Ccutlass::gemm::GemmConfig< half, half, half, half, OutputTile_, ThreadMultiplyAdd< AccumulatorsPerThread_, Shape< 1, 4, 8 >, half, half, half >, kScalarsPerLdgA_, kScalarsPerLdgA_, 8, kScalarsPerLdgB_, kScalarsPerLdgB_, 8, 2, 8, 2, 2 >
     Ccutlass::gemm::GemmConfig< int8_t, int8_t, int8_t, int8_t, OutputTile_, ThreadMultiplyAdd< AccumulatorsPerThread_, Shape< 1, 4, 8 >, int8_t, int8_t, int >, 4, 4, 16, 4, 4, 16, 4, 4, 4, 2 >
     Ccutlass::gemm::GemmConfig< int8_t, int8_t, ScalarD_, ScalarD_, OutputTile_, ThreadMultiplyAdd< AccumulatorsPerThread_, Shape< 1, 4, 8 >, int8_t, int8_t, int >, 4, 4, 16, 4, 4, 16, 1, 4, 1, 2 >
     Ccutlass::gemm::GemmDesc< Scalar_, Index_ >
     Ccutlass::gemm::GemmEpilogue< GemmEpilogueTraits_ >
     Ccutlass::gemm::GemmEpilogueTraits< OutputTile_, Accumulators_, GlobalLoadIteratorC_, GlobalTransformerC_, GlobalTransformerD_, GlobalStoreIteratorD_, SharedStoreIteratorD_, SharedStoreTransformerD_, SharedLoadIteratorD_, Iterations_, Delta_, Functor_, Index_ >
     Ccutlass::gemm::GemmEpilogueTraits< GemmConfig_::OutputTile, GemmConfig_::Accumulators, Helper_::GlobalLoadIteratorC, Helper_::GlobalTransformerC, Helper_::GlobalTransformerD, Helper_::GlobalStoreIteratorD, Helper_::SharedStoreIteratorD, Helper_::SharedStoreTransformerD, Helper_::SharedLoadIteratorD, Helper_::Iterations, Helper_::Delta, EpilogueFunctor_, Index_ >
     Ccutlass::gemm::GemmEpilogueTraits< IgemmConfig_::OutputTile, IgemmConfig_::Accumulators, Helper_::GlobalLoadIteratorC, Helper_::GlobalTransformerC, Helper_::GlobalTransformerD, Helper_::GlobalStoreIteratorD, Helper_::SharedStoreIteratorD, Helper_::SharedStoreTransformerD, Helper_::SharedLoadIteratorD, Helper_::Iterations, Helper_::Delta, EpilogueFunctor_, Index_ >
     Ccutlass::gemm::GemmEpilogueTraitsHelper< GemmConfig_, EpilogueFunctor_, Index_ >
     Ccutlass::gemm::GemmEpilogueTraitsHelper< IgemmConfig_, EpilogueFunctor_, Index_ >
     Ccutlass::gemm::GemmGlobalTileTraits< kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_ >
     Ccutlass::gemm::GemmGlobalTileTraits< GemmOperand::kC, MatrixLayout::kColumnMajor, Scalar_, Tile_, Threads_, kAccessSize_ >
     Ccutlass::gemm::GemmMultiplicandTraits< ThreadBlockTile_, Usage, Layout >
     Ccutlass::GemmOperandGemm operand - D = A * B + C
     Ccutlass::gemm::GemmOperandTraitsAb< kOperand_, kLayout_ >Helper to describe attributes of GEMM matrix operands
     Ccutlass::gemm::GemmSharedLoadTileATraits< Scalar_, OutputTile_, Warps_, ThreadsPerWarp_, InstructionShape_, kStages_, kScalarsPerLds_, kSkew_ >
     Ccutlass::gemm::GemmSharedLoadTileBTraits< Scalar_, OutputTile_, Warps_, ThreadsPerWarp_, InstructionShape_, kStages_, kScalarsPerLds_, kSkew_ >
     Ccutlass::gemm::GemmSharedLoadTileDTraits< Scalar_, OutputTile_, Warps_, ThreadsPerWarp_, kTileH_, kScalarsPerLds_, kSkew_ >
     Ccutlass::gemm::GemmSharedStoreTileAbTraits< Scalar_, Tile_, Threads_, kScalarsPerSts_ >
     Ccutlass::gemm::GemmSharedStoreTileDTraits< Scalar_, OutputTile_, Warps_, ThreadsPerWarp_, kScalarsPerSts_, kSkew_ >
     Ccutlass::gemm::GemmSharedStoreWithSkewTileAbTraits< Scalar_, Tile_, Threads_, kScalarsPerSts_, kSkew_ >
     Ccutlass::gemm::GemmTileTraitsHelperA< Kind, GemmConfig_ >
     Ccutlass::gemm::GemmTileTraitsHelperA< kLayout_, GemmConfig_ >
     Ccutlass::gemm::GemmTileTraitsHelperA< MatrixLayout::kColumnMajor, GemmConfig_ >
     Ccutlass::gemm::GemmTileTraitsHelperA< MatrixLayout::kRowMajor, GemmConfig_ >
     Ccutlass::gemm::GemmTileTraitsHelperB< Kind, GemmConfig_ >
     Ccutlass::gemm::GemmTileTraitsHelperB< kLayout_, GemmConfig_ >
     Ccutlass::gemm::GemmTileTraitsHelperB< MatrixLayout::kColumnMajor, GemmConfig_ >
     Ccutlass::gemm::GemmTileTraitsHelperB< MatrixLayout::kRowMajor, GemmConfig_ >
     Ccutlass::gemm::GemmTraits< GemmConfig_, GlobalLoadStreamA_, GlobalLoadStreamB_, SharedLoadStreamA_, SharedLoadStreamB_, Epilogue_, BlockSwizzle_, Index_, ClearAccumulators_ >
     Ccutlass::gemm::GemmTraits< GemmConfig_, Helper_::GlobalLoadStreamA, Helper_::GlobalLoadStreamB, Helper_::SharedLoadStreamA, Helper_::SharedLoadStreamB, Epilogue_, IdentityBlockSwizzle, Index_, ClearAccumulators< GemmConfig_::Accumulators::Element > >
     Ccutlass::gemm::GemmTraits< GemmConfig_, SimplifiedGemmTraitsHelper< GemmTileTraitsHelperA< kLayoutA_, GemmConfig_ >, GemmTileTraitsHelperB< kLayoutB_, GemmConfig_ >, Index_ > ::GlobalLoadStreamA, SimplifiedGemmTraitsHelper< GemmTileTraitsHelperA< kLayoutA_, GemmConfig_ >, GemmTileTraitsHelperB< kLayoutB_, GemmConfig_ >, Index_ > ::GlobalLoadStreamB, SimplifiedGemmTraitsHelper< GemmTileTraitsHelperA< kLayoutA_, GemmConfig_ >, GemmTileTraitsHelperB< kLayoutB_, GemmConfig_ >, Index_ > ::SharedLoadStreamA, SimplifiedGemmTraitsHelper< GemmTileTraitsHelperA< kLayoutA_, GemmConfig_ >, GemmTileTraitsHelperB< kLayoutB_, GemmConfig_ >, Index_ > ::SharedLoadStreamB, GemmEpilogue< GemmEpilogueTraits_ >, IdentityBlockSwizzle, Index_, ClearAccumulators< GemmConfig_::Accumulators::Element > >
     Ccutlass::gemm::GemmTraits< Helper_::GemmConfig, Helper_::GlobalLoadStreamA, Helper_::GlobalLoadStreamB, Helper_::SharedLoadStreamA, Helper_::SharedLoadStreamB, Helper_::Epilogue, IdentityBlockSwizzle, Index_, Helper_::ClearAccumulators >
     Ccutlass::gemm::GetExtent< kOperand_, Tile_ >
     Ccutlass::gemm::GetExtent< GemmOperand::kA, Tile_ >
     Ccutlass::gemm::GetExtent< GemmOperand::kB, Tile_ >
     Ccutlass::gemm::GemmTraits< GemmConfig_, GlobalLoadStreamA_, GlobalLoadStreamB_, SharedLoadStreamA_, SharedLoadStreamB_, Epilogue_, BlockSwizzle_, Index_, ClearAccumulators_ >::GlobalLoadStreamAssemble the global load streams for A/B
     Ccutlass::gemm::GlobalLoadStreamBase< LoadIterator_, StoreIterator_, Transformer_ >
     Ccutlass::platform::greater< T >Std::greater
     Ccutlass::gemm::HgemmSwizzle< GlobalIterator_ >
     Ccutlass::gemm::HgemmTraitsHelper< kLayoutA_, kLayoutB_, OutputTile_, EpilogueFunctor_, AccumulatorsPerThread_, kScalarsPerLdgA_, kScalarsPerLdgB_, Index_ >
     Ccutlass::gemm::HgemmTransformerA< kLayout_, Iterator_ >
     Ccutlass::gemm::HgemmTransformerA< MatrixLayout::kColumnMajor, Iterator_ >
     Ccutlass::gemm::HgemmTransformerA< MatrixLayout::kRowMajor, Iterator_ >
     Ccutlass::gemm::HgemmTransformerB< kLayout_, Iterator_ >
     Ccutlass::gemm::HgemmTransformerB< MatrixLayout::kColumnMajor, Iterator_ >
     Ccutlass::gemm::HgemmTransformerB< MatrixLayout::kRowMajor, Iterator_ >
     Ccutlass::IdentityDescribes identity elements
     Ccutlass::gemm::IdentityBlockSwizzle
     Ccutlass::gemm::IgemmEpilogueScalar< ScalarD_ >
     Ccutlass::gemm::IgemmEpilogueScalar< int >
     Ccutlass::gemm::IgemmFloatToInt8Converter< kElements_ >
     Ccutlass::gemm::IgemmGlobalLoadTransformer< InputFragment_, OutputScalar_ >
     Ccutlass::gemm::IgemmGlobalLoadTransformer< Fragment< int8_t, kElements_ >, float >
     Ccutlass::gemm::IgemmGlobalStoreTransformer< InputScalar_, OutputFragment_ >
     Ccutlass::gemm::IgemmGlobalStoreTransformer< float, Fragment< int8_t, kElements_ > >
     Ccutlass::gemm::IgemmInt8ToFloatConverter< kElements_ >
     Ccutlass::gemm::IgemmSharedStoreTransformer< InputScalar_, OutputFragment_ >
     Ccutlass::gemm::IgemmSwizzle< GlobalIterator_ >
     Ccutlass::gemm::IgemmTraitsHelper< kLayoutA_, kLayoutB_, OutputTile_, ScalarD_, EpilogueFunctor_, AccumulatorsPerThread_, Index_ >
     Ccutlass::gemm::IgemmTransformerA< kLayout_, Iterator_ >
     Ccutlass::gemm::IgemmTransformerA< MatrixLayout::kColumnMajor, Iterator_ >
     Ccutlass::gemm::IgemmTransformerA< MatrixLayout::kRowMajor, Iterator_ >
     Ccutlass::gemm::IgemmTransformerB< kLayout_, Iterator_ >
     Ccutlass::gemm::IgemmTransformerB< MatrixLayout::kColumnMajor, Iterator_ >
     Ccutlass::gemm::IgemmTransformerB< MatrixLayout::kRowMajor, Iterator_ >
     Ccutlass::platform::integral_constant< value_t, V >Std::integral_constant
     Ccutlass::platform::integral_constant< bool, V >
     Ccutlass::platform::integral_constant< bool,(is_arithmetic< T >::value||is_void< T >::value||is_same< nullptr_t, remove_cv< T >::type >::value)>
     Ccutlass::platform::integral_constant< bool,(is_base_of_helper< remove_cv< BaseT >::type, remove_cv< DerivedT >::type >::value)||(is_same< remove_cv< BaseT >::type, remove_cv< DerivedT >::type >::value)>
     Ccutlass::platform::integral_constant< bool,(is_fundamental< T >::value||is_pointer< T >::value)>
     Ccutlass::platform::integral_constant< bool,(is_integral< T >::value||is_floating_point< T >::value)>
     Ccutlass::platform::integral_constant< bool,(is_same< float, remove_cv< T >::type >::value||is_same< double, remove_cv< T >::type >::value)>
     Ccutlass::platform::integral_constant< bool,(N &(N - 1))==0 >
     Ccutlass::platform::is_base_of_helper< BaseT, DerivedT >Helper for std::is_base_of
     Ccutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::IteratorAn iterator implementing Predicate Iterator Concept enabling sequential read and write access to predicates
     Ccutlass::IteratorAdvanceSpecifies dimension in which post-increment accesses advance
     Ccutlass::IteratorFragmentSpecifies whether iterator storage fragment consists of Scalar values or WMMA matrix
     Ccutlass::platform::less< T >Std::less
     Ccutlass::gemm::LinearScaling< Scalar_, FragmentMultiplyAdd_ >Functor to compute linear combination of fragments
     Ccutlass::Load< Scalar_, Lanes_, Memory_, bool, size_t >
     Ccutlass::Load< double, 2, Memory_, true, 16 >
     Ccutlass::Load< Scalar_, Lanes_, Memory_, true, 16 >
     Ccutlass::Load< Scalar_, Lanes_, Memory_, true, 4 >
     Ccutlass::Load< Scalar_, Lanes_, Memory_, true, 8 >
     Ccutlass::log2_down< N, CurrentVal, Count >
     Ccutlass::log2_down< N, 1, Count >
     Ccutlass::log2_up< N, CurrentVal, Count >
     Ccutlass::log2_up< N, 1, Count >
     Ccutlass::gemm::GemmTraits< GemmConfig_, GlobalLoadStreamA_, GlobalLoadStreamB_, SharedLoadStreamA_, SharedLoadStreamB_, Epilogue_, BlockSwizzle_, Index_, ClearAccumulators_ >::MainLoopSharedStorage
     Ccutlass::MatrixLayoutDescribes layouts of matrices
     Ccutlass::MemorySpaceEnum to specify which memory space data resides in
     Ccutlass::platform::nullptr_tStd::nullptr_t
     Ccutlass::platform::alignment_of< value_t >::pad
     Ccutlass::gemm::WmmaGemmGlobalIteratorCd< TileTraits_, Index_ >::ParamsThe params
     CParams
     Ccutlass::gemm::GemmTraits< GemmConfig_, GlobalLoadStreamA_, GlobalLoadStreamB_, SharedLoadStreamA_, SharedLoadStreamB_, Epilogue_, BlockSwizzle_, Index_, ClearAccumulators_ >::ParamsThe params
     Ccutlass::gemm::GlobalLoadStreamBase< LoadIterator_, StoreIterator_, Transformer_ >::ParamsThe params
     Ccutlass::TileIteratorBase< Traits_, Scalar_, Advance_, MemorySpace, Index_, FragmentElement_, IteratorFragment_, Skew_ >::ParamsParameters to the iterator
     Ccutlass::gemm::GemmGlobalIteratorCd< TileTraits_, Index_ >::ParamsThe params
     Ccutlass::gemm::GemmEpilogueTraits< OutputTile_, Accumulators_, GlobalLoadIteratorC_, GlobalTransformerC_, GlobalTransformerD_, GlobalStoreIteratorD_, SharedStoreIteratorD_, SharedStoreTransformerD_, SharedLoadIteratorD_, Iterations_, Delta_, Functor_, Index_ >::ParamsThe params
     Ccutlass::gemm::SharedLoadStream< Iterator_, Transformer_ >::ParamsThe params
     Ccutlass::gemm::LinearScaling< Scalar_, FragmentMultiplyAdd_ >::ParamsThe parameters
     Ccutlass::platform::plus< T >Platform::plus
     Ccutlass::PredicateTileAdapter< PredicateVector_, Iterations_ >Adapter to enable random access to predicates via logical coordinate within a tile
     Ccutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >Statically sized array of bits implementing
     Ccutlass::PredicateVector< Base::Iterations::kW >
     Ccutlass::PredicateVector< ShapeCount< typename Base::Iterations >::kCount >
     Ccutlass::gemm::ProjectOperand< operand, Kstrided >
     Ccutlass::gemm::ProjectOperand< GemmOperand::kA, Kstrided >Project A operand - (0, K, M)
     Ccutlass::gemm::ProjectOperand< GemmOperand::kB, Kstrided >Project B operand - (0, K, N)
     Ccutlass::gemm::ProjectOperand< GemmOperand::kC, true >Project C operand - (0, N, M)
     Ccutlass::gemm::ProjectOperand< GemmOperand::kD, true >Project D operand - (0, N, M)
     Ccutlass::platform::remove_const< T >Std::remove_const (non-const specialization)
     Ccutlass::platform::remove_const< const T >Std::remove_const (const specialization)
     Ccutlass::platform::remove_cv< T >Std::remove_cv
     Ccutlass::platform::remove_volatile< T >Std::remove_volatile (non-volatile specialization)
     Ccutlass::platform::remove_volatile< volatile T >Std::remove_volatile (volatile specialization)
     Ccutlass::gemm::ReshapeThreads< Tile_, Threads_, bool >
     Ccutlass::gemm::ReshapeThreads< Tile_, Threads_, true >
     Ccutlass::ReshapeTile< Tile_, kAccessSize_, bool >
     Ccutlass::ReshapeTile< Tile_, kAccessSize_, true >
     Ccutlass::Shape< kD_, kH_, kW_, kC_ >A Shape implementing Layout Concept describing the dimensions of a cube
     Ccutlass::ShapeAdd< A_, B_ >
     Ccutlass::ShapeCount< Shape >Compute derived counted of a Layout Concept based class
     Ccutlass::ShapeDiv< A_, B_ >
     Ccutlass::ShapeMax< A_, B_ >
     Ccutlass::ShapeMin< A_, B_ >
     Ccutlass::ShapeMul< A_, B_ >
     Ccutlass::ShapeScale< A_, kScale_ >
     Ccutlass::ShapeStrides< Shape_ >
     Ccutlass::ShapeSub< A_, B_ >
     Ccutlass::gemm::GemmTraits< GemmConfig_, GlobalLoadStreamA_, GlobalLoadStreamB_, SharedLoadStreamA_, SharedLoadStreamB_, Epilogue_, BlockSwizzle_, Index_, ClearAccumulators_ >::SharedLoadStreamAssemble the shared load stream for A/B
     Ccutlass::gemm::SharedLoadStream< Iterator_, Transformer_ >
     Ccutlass::gemm::ClearAccumulators< Scalar_, kLanes_ >::SharedStorageThe shared storage
     Ccutlass::gemm::GemmEpilogueTraits< OutputTile_, Accumulators_, GlobalLoadIteratorC_, GlobalTransformerC_, GlobalTransformerD_, GlobalStoreIteratorD_, SharedStoreIteratorD_, SharedStoreTransformerD_, SharedLoadIteratorD_, Iterations_, Delta_, Functor_, Index_ >::SharedStorageThe shared memory to swizzle the data in the epilogue
     Ccutlass::gemm::GemmTraits< GemmConfig_, GlobalLoadStreamA_, GlobalLoadStreamB_, SharedLoadStreamA_, SharedLoadStreamB_, Epilogue_, BlockSwizzle_, Index_, ClearAccumulators_ >::SharedStorageThe storage in shared memory
     Ccutlass::gemm::GlobalLoadStreamBase< LoadIterator_, StoreIterator_, Transformer_ >::SharedStorageThe storage in shared memory needed by that stream
     Ccutlass::gemm::SimplifiedGemmTraitsHelper< GemmTileTraitsHelperA_, GemmTileTraitsHelperB_, Index_ >
     Ccutlass::sqrt_est< N >
     Ccutlass::StorageType< kAlignment_ >
     Ccutlass::StorageType< 1 >
     Ccutlass::StorageType< 2 >
     Ccutlass::StorageType< 4 >
     Ccutlass::Store< Scalar_, Lanes_, Memory_, bool, size_t >
     Ccutlass::Store< double, 2, Memory_, true, 16 >
     Ccutlass::Store< Scalar_, Lanes_, Memory_, true, 16 >
     Ccutlass::Store< Scalar_, Lanes_, Memory_, true, 4 >
     Ccutlass::Store< Scalar_, Lanes_, Memory_, true, 8 >
     Ccutlass::gemm::GemmTraits< GemmConfig_, GlobalLoadStreamA_, GlobalLoadStreamB_, SharedLoadStreamA_, SharedLoadStreamB_, Epilogue_, BlockSwizzle_, Index_, ClearAccumulators_ >::StreamSharedStorage< GlobalLoadStream_, SharedLoadStream_ >
     Ccutlass::gemm::GemmEpilogueTraits< OutputTile_, Accumulators_, GlobalLoadIteratorC_, GlobalTransformerC_, GlobalTransformerD_, GlobalStoreIteratorD_, SharedStoreIteratorD_, SharedStoreTransformerD_, SharedLoadIteratorD_, Iterations_, Delta_, Functor_, Index_ >::StreamSharedStorageThe shared memory storage to exchange data
     Ccutlass::gemm::GemmTraits< GemmConfig_, GlobalLoadStreamA_, GlobalLoadStreamB_, SharedLoadStreamA_, SharedLoadStreamB_, Epilogue_, BlockSwizzle_, Index_, ClearAccumulators_ >::StreamSharedStorage< GlobalLoadStreamA, SharedLoadStreamA >
     Ccutlass::gemm::GemmTraits< GemmConfig_, GlobalLoadStreamA_, GlobalLoadStreamB_, SharedLoadStreamA_, SharedLoadStreamB_, Epilogue_, BlockSwizzle_, Index_, ClearAccumulators_ >::StreamSharedStorage< GlobalLoadStreamB, SharedLoadStreamB >
     Ccutlass::TensorRef< Storage_, Rank_ >Structure modeling a pointer and stride into a tensor
     Ccutlass::TensorRef< T, 4 >
     Ccutlass::gemm::ThreadMultiplyAdd< AccumulatorsPerThread_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_ >Template performing matrix multiply-add operation within a thread
     Ccutlass::gemm::ThreadMultiplyAdd< AccumulatorsPerThread_, ThreadsPerWarp_, half, half, half >Template performing matrix multiply-add operation within a thread
     Ccutlass::gemm::ThreadMultiplyAdd< AccumulatorsPerThread_, ThreadsPerWarp_, int8_t, int8_t, int >Template performing matrix multiply-add operation within a thread
     Ccutlass::gemm::GemmSharedLoadTileBTraits< Scalar_, OutputTile_, Warps_, ThreadsPerWarp_, InstructionShape_, kStages_, kScalarsPerLds_, kSkew_ >::ThreadOffsetComputes the thread offset in (H, W) based on thread ID
     Ccutlass::gemm::GemmGlobalTileCdTraits< Scalar_, Tile_, Threads_, kStrideH_, kAccessSize_ >::ThreadOffsetComputes the thread offset in (H, W) based on thread ID
     Ccutlass::gemm::IgemmContiguousGlobalTileTraits< kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_ >::ThreadOffsetComputes the thread offset in (H, W) based on thread ID
     Ccutlass::gemm::GemmGlobalTileTraits< kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_ >::ThreadOffsetComputes the thread offset in (H, W) based on thread ID
     Ccutlass::gemm::GemmSharedLoadTileDTraits< Scalar_, OutputTile_, Warps_, ThreadsPerWarp_, kTileH_, kScalarsPerLds_, kSkew_ >::ThreadOffsetComputes the thread offset in (H, W) based on thread ID
     Ccutlass::gemm::GemmSharedLoadTileATraits< Scalar_, OutputTile_, Warps_, ThreadsPerWarp_, InstructionShape_, kStages_, kScalarsPerLds_, kSkew_ >::ThreadOffsetComputes the thread offset in (H, W) based on thread ID
     Ccutlass::gemm::GemmSharedStoreTileDTraits< Scalar_, OutputTile_, Warps_, ThreadsPerWarp_, kScalarsPerSts_, kSkew_ >::ThreadOffsetComputes the thread offset in (H, W) based on thread ID
     Ccutlass::gemm::HgemmCrosswiseGlobalTileTraits< kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_ >::ThreadOffsetComputes the thread offset in (H, W) based on thread ID
     Ccutlass::gemm::GemmSharedStoreTileAbTraits< Scalar_, Tile_, Threads_, kScalarsPerSts_ >::ThreadOffset
     Ccutlass::TileTraitsWarpRake< Tile_, Threads >::ThreadOffsetComputes the thread offset in (H, W) based on thread ID
     Ccutlass::gemm::GemmSharedStoreWithSkewTileAbTraits< Scalar_, Tile_, Threads_, kScalarsPerSts_, kSkew_ >::ThreadOffset
     Ccutlass::gemm::WmmaGemmGlobalIteratorCdTraits< Scalar_, Tile_, Threads_, kAccessSize_ >::ThreadOffsetComputes the thread offset in (H, W) based on thread ID
     Ccutlass::TiledThreadOffset< ThreadShape >Basic thread offset function computed from a thread shape
     Ccutlass::TileIteratorBase< Traits_, Scalar_, Advance_, MemorySpace, Index_, FragmentElement_, IteratorFragment_, Skew_ >Iterator for accessing a stripmined tile in memory
     Ccutlass::TileIteratorBase< TileTraits_, TileTraits_::Scalar, Advance_, MemorySpace, Index_, TileTraits_::Scalar, IteratorFragment::kScalar, Shape< 0, 0, 0, 0 > >
     Ccutlass::TileIteratorBase< TileTraits_, TileTraits_::Scalar, IteratorAdvance::kH, MemorySpace::kGlobal, Index_ >
     Ccutlass::TileTraits< Tile_, Delta_, Iterations_, ThreadOffset_ >A template defining Tile Traits Concept
     Ccutlass::TileTraitsContiguousMajor< Tile_, Threads >
     Ccutlass::TileTraitsStandard< Tile_, Threads >Chooses 'best' shape to enable warp raking along contiguous dimension if possible
     Ccutlass::TileTraitsStrideMajor< Tile_, Threads >
     Ccutlass::TileTraitsWarpRake< Tile_, Threads >Tiling in which warps rake across the contiguous dimension
     Ccutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::TrivialIteratorIterator that always returns true
     Ccutlass::TrivialPredicateTileAdapterAlways returns true predicate
     Ccutlass::platform::unique_ptr< T, Deleter >Std::unique_ptr
     Ccutlass::Vector< Scalar_, kLanes_ >
     Ccutlass::Vector< half, kLanes_ >
     Ccutlass::Vectorize< Element_, kLanes_ >
     Ccutlass::Vectorize< Element_, 1 >
     Ccutlass::VectorTraits< T >Traits describing properties of vectors and scalar-as-vectors
     Ccutlass::VectorTraits< Vector< T, Lanes > >Partial specialization for actual cutlass::Vector
     Ccutlass::VectorTraits< Vector< T, Lanes > const >Partial specialization for actual cutlass::Vector
     Ccutlass::bin1_t
     Ccutlass::gemm::ClearAccumulators< Scalar_, kLanes_ >
     Ccutlass::MatrixLayout::ColumnMajorMapping function for column-major matrices
     Ccutlass::MatrixLayout::ColumnMajorBlockLinear< BlockRows, BlockColumns >
     Ccutlass::gemm::ColumnMajorBlockSwizzle< groupCols, swDirection >
     Ccutlass::MatrixLayout::ColumnMajorInterleaved< Interleave >
     Ccutlass::platform::complex< T >
     Ccutlass::ComputeOffsetFromShape< Shape_ >Compute the offset for the given coordinates in a cube
     Ccutlass::ComputeOffsetFromStrides< Strides_ >Compute the offset for the given coordinates in a cube
     Ccutlass::ComputeThreadOffsetFromStrides< Threads_, Strides_ >Decompose threadId.x into coordinate of a cube whose dimensions are specified by Threads_. Afterwards compute the offset of those coordinates using Strides_
     Ccutlass::ComputeThreadOffsetFromStrides< Shape< 1, T_h_, T_w_, 1 >, Shape< 1, S_h_, S_w_, 1 > >Specialization for D=1 and C=1
     Ccutlass::ComputeThreadOffsetFromStrides< Shape< 1, T_h_, T_w_, T_c_ >, Shape< 1, S_h_, S_w_, S_c_ > >Specialization for D=1
     Ccutlass::platform::conditional< B, T, F >Std::conditional (true specialization)
     Ccutlass::platform::conditional< false, T, F >Std::conditional (false specialization)
     Ccutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::ConstIteratorA const iterator implementing Predicate Iterator Concept enabling sequential read-only access to prediactes
     Ccutlass::TensorRefBatchStrided< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIteratorConstant iterator over tensors implied by TensorRefBatchStrided
     Ccutlass::TensorRefArray< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >::ConstIteratorTensorRefIterator over TensorRef objects in TensorRefArray
     Ccutlass::ConstPredicateTileAdapter< PredicateVector_, Iterations_ >Adapter to enable random access to predicates via logical coordinate within a tile
     Ccutlass::MatrixLayout::ContiguousLayout
     Ccutlass::Convert< InputFragment_, OutputFragment_ >
     Ccutlass::Convert< Fragment< InputScalar_, kScalars_ >, Fragment< OutputScalar_, kScalars_ > >
     Ccutlass::Coord< Rank_, Index_ >Statically-sized array specifying Coords within a tensor
     Ccutlass::Coord< 2, int >
     Ccutlass::Coord< 3 >
     Ccutlass::Coord< 4 >
     Ccutlass::Coord< 4, Index_ >
     Ccutlass::Coord< 4, int >
     Ccutlass::Coord< kStorageRank - 1 >
     Ccutlass::Copy< Fragment_ >
     CDebugType< T >
     CDebugValue< Value >
     Ccutlass::platform::default_delete< T >Default deleter
     Ccutlass::platform::default_delete< T[]>Partial specialization for deleting array types
     Ccutlass::divide_assert< Dividend, Divisor >
     Ccutlass::platform::is_base_of_helper< BaseT, DerivedT >::dummy< B, D >
     Ccutlass::DumpType< T >
     Ccutlass::platform::enable_if< C, T >Std::enable_if (true specialization)
     Ccutlass::platform::enable_if< false, T >Std::enable_if (false specialization)
     Ccutlass::Extent< T >Returns the extent of a scalar or vector
     Ccutlass::Extent< Vector< T, Lanes > >Returns the number of lanes of a vector if need be
     Ccutlass::Extent< Vector< T, Lanes > const >Returns the number of lanes of a vector if need be
     Ccutlass::FragmentConstIterator< Fragment_, Iterations_, AccessType_ >
     Ccutlass::FragmentElementTypeSpecifies whether iterator storage fragment consists of Scalar values or WMMA matrix
     Ccutlass::FragmentIterator< Fragment_, Iterations_, AccessType_ >A template defining Fragment Iterator Concept
     Ccutlass::gemm::FragmentMultiplyAdd< ScalarAlphaBeta_, ScalarAccum_, fragMul2 >
     Ccutlass::gemm::FragmentMultiplyAdd< half, half, true >
     Ccutlass::gemm::Gemm< GemmTraits_ >
     Ccutlass::gemm::GemmConfig< ScalarA_, ScalarB_, ScalarC_, ScalarD_, OutputTile_, MultiplyAdd_, kScalarsPerLdgA_, kScalarsPerStsA_, kScalarsPerLdsA_, kScalarsPerLdgB_, kScalarsPerStsB_, kScalarsPerLdsB_, kScalarsPerLdgCAndStgD_, kScalarsPerStsD_, kScalarsPerLdsD_, kStages_, kResidueSeparate_, kResidueInProlog_, kLaunchBounds_ >
     Ccutlass::gemm::GemmConfig< double, double, double, double, OutputTile_, ThreadMultiplyAdd< ThreadGemmShape_, Shape< 1, 4, 8 >, double, double, double >, kScalarsPerLdgA_, kScalarsPerLdgA_, 2, kScalarsPerLdgB_, kScalarsPerLdgB_, 2, 1, 2, 1, 2, false, false, false >
     Ccutlass::gemm::GemmConfig< float, float, float, float, OutputTile_, ThreadMultiplyAdd< ThreadGemmShape_, Shape< 1, 4, 8 >, float, float, float >, kScalarsPerLdgA_, kScalarsPerLdgA_, 4, kScalarsPerLdgB_, kScalarsPerLdgB_, 4, 1, 4, 1, 2, false, true, kLaunchBounds >
     Ccutlass::gemm::GemmConfig< half, half, half, half, OutputTile_, ThreadMultiplyAdd< ThreadGemmShape_, Shape< 1, 4, 8 >, half, half, half >, kScalarsPerLdgA_, kScalarsPerLdgA_, 8, kScalarsPerLdgB_, kScalarsPerLdgB_, 8, 2, 8, 2, 2, false, true, false >
     Ccutlass::gemm::GemmConfig< int8_t, int8_t, int8_t, int8_t, OutputTile_, ThreadMultiplyAdd< ThreadGemmShape_, Shape< 1, 4, 8 >, int8_t, int8_t, int >, 4, 4, 16, 4, 4, 16, 4, 4, 4, 2, false, true, false >
     Ccutlass::gemm::GemmConfig< int8_t, int8_t, ScalarD_, ScalarD_, OutputTile_, ThreadMultiplyAdd< ThreadGemmShape_, Shape< 1, 4, 8 >, int8_t, int8_t, int >, 4, 4, 16, 4, 4, 16, 1, 4, 1, 2, false, false, false >
     Ccutlass::gemm::GemmConfig< ScalarA_, ScalarB_, ScalarC_, ScalarD_, OutputTile_, ThreadMultiplyAdd< ThreadGemmShape_, Shape< 1, 4, 8 >, ScalarA_, ScalarB_, float >, kScalarsPerLdgA_, kScalarsPerLdgA_, 4, kScalarsPerLdgB_, kScalarsPerLdgB_, 4, 1, 4, 1, 2 >
     Ccutlass::gemm::GemmDesc< AType_, BType_, CType_, DType_, SType_, Index_ >GEMM problem description
     Ccutlass::gemm::GemmEpilogue< GemmEpilogueTraits_ >
     Ccutlass::gemm::GemmEpilogueTraits< OutputTile_, Accumulators_, GlobalLoadIteratorC_, GlobalTransformerC_, GlobalTransformerD_, GlobalStoreIteratorD_, SharedStoreIteratorD_, SharedStoreTransformerD_, SharedLoadStreamD_, Iterations_, Delta_, Functor_, Index_ >
     Ccutlass::gemm::GemmEpilogueTraits< GemmConfig_::OutputTile, GemmConfig_::Accumulators, Helper_::GlobalLoadIteratorC, Helper_::GlobalTransformerC, Helper_::GlobalTransformerD, Helper_::GlobalStoreIteratorD, Helper_::SharedStoreIteratorD, Helper_::SharedStoreTransformerD, Helper_::SharedLoadStreamD, Helper_::Iterations, Helper_::Delta, EpilogueFunctor_, Index_ >
     Ccutlass::gemm::GemmEpilogueTraits< IgemmConfig_::OutputTile, IgemmConfig_::Accumulators, Helper_::GlobalLoadIteratorC, Helper_::GlobalTransformerC, Helper_::GlobalTransformerD, Helper_::GlobalStoreIteratorD, Helper_::SharedStoreIteratorD, Helper_::SharedStoreTransformerD, Helper_::SharedLoadStreamD, Helper_::Iterations, Helper_::Delta, EpilogueFunctor_, Index_ >
     Ccutlass::gemm::GemmEpilogueTraitsHelper< GemmConfig_, EpilogueFunctor_, Index_ >
     Ccutlass::gemm::GemmEpilogueTraitsHelper< IgemmConfig_, EpilogueFunctor_, Index_ >
     Ccutlass::gemm::GemmGlobalTileTraits< kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_ >
     Ccutlass::gemm::GemmGlobalTileTraits< GemmOperand::kC, MatrixLayout::kColumnMajor, Scalar_, Tile_, Threads_, kAccessSize_ >
     Ccutlass::gemm::GemmMultiplicandTraits< ThreadBlockTile_, Usage, Layout >
     Ccutlass::GemmOperandGemm operand - D = A * B + C
     Ccutlass::gemm::GemmOperandTraitsAb< kOperand_, kLayout_ >Helper to describe attributes of GEMM matrix operands
     Ccutlass::gemm::GemmSharedLoadTileATraits< Scalar_, OutputTile_, Warps_, ThreadsPerWarp_, InstructionShape_, kStages_, kScalarsPerLds_, kSkew_ >
     Ccutlass::gemm::GemmSharedLoadTileBTraits< Scalar_, OutputTile_, Warps_, ThreadsPerWarp_, InstructionShape_, kStages_, kScalarsPerLds_, kSkew_ >
     Ccutlass::gemm::GemmSharedLoadTileDTraits< Scalar_, OutputTile_, Warps_, ThreadsPerWarp_, kTileH_, kScalarsPerLds_, kSkew_ >
     Ccutlass::gemm::GemmSharedStoreTileAbTraits< Scalar_, Tile_, Threads_, kScalarsPerSts_ >
     Ccutlass::gemm::GemmSharedStoreTileDTraits< Scalar_, OutputTile_, Warps_, ThreadsPerWarp_, kScalarsPerSts_, kSkew_ >
     Ccutlass::gemm::GemmSharedStoreWithSkewTileAbTraits< Scalar_, Tile_, Threads_, kScalarsPerSts_, kSkew_ >
     Ccutlass::gemm::GemmTileTraitsHelperA< Kind, GemmConfig_ >
     Ccutlass::gemm::GemmTileTraitsHelperA< kLayout_, GemmConfig_ >
     Ccutlass::gemm::GemmTileTraitsHelperA< MatrixLayout::kColumnMajor, GemmConfig_ >
     Ccutlass::gemm::GemmTileTraitsHelperA< MatrixLayout::kRowMajor, GemmConfig_ >
     Ccutlass::gemm::GemmTileTraitsHelperB< Kind, GemmConfig_ >
     Ccutlass::gemm::GemmTileTraitsHelperB< kLayout_, GemmConfig_ >
     Ccutlass::gemm::GemmTileTraitsHelperB< MatrixLayout::kColumnMajor, GemmConfig_ >
     Ccutlass::gemm::GemmTileTraitsHelperB< MatrixLayout::kRowMajor, GemmConfig_ >
     Ccutlass::gemm::GemmTraits< GemmConfig_, GlobalLoadStreamA_, GlobalLoadStreamB_, SharedLoadStreamA_, SharedLoadStreamB_, Epilogue_, BlockSwizzle_, Index_, ClearAccumulators_ >
     Ccutlass::gemm::GemmTraits< GemmConfig_, Helper_::GlobalLoadStreamA, Helper_::GlobalLoadStreamB, Helper_::SharedLoadStreamA, Helper_::SharedLoadStreamB, Epilogue_, IdentityBlockSwizzle, Index_, ClearAccumulators< GemmConfig_::Accumulators::Element > >
     Ccutlass::gemm::GemmTraits< GemmConfig_, SimplifiedGemmTraitsHelper< GemmTileTraitsHelperA< kLayoutA_, GemmConfig_ >, GemmTileTraitsHelperB< kLayoutB_, GemmConfig_ >, Index_ > ::GlobalLoadStreamA, SimplifiedGemmTraitsHelper< GemmTileTraitsHelperA< kLayoutA_, GemmConfig_ >, GemmTileTraitsHelperB< kLayoutB_, GemmConfig_ >, Index_ > ::GlobalLoadStreamB, SimplifiedGemmTraitsHelper< GemmTileTraitsHelperA< kLayoutA_, GemmConfig_ >, GemmTileTraitsHelperB< kLayoutB_, GemmConfig_ >, Index_ > ::SharedLoadStreamA, SimplifiedGemmTraitsHelper< GemmTileTraitsHelperA< kLayoutA_, GemmConfig_ >, GemmTileTraitsHelperB< kLayoutB_, GemmConfig_ >, Index_ > ::SharedLoadStreamB, GemmEpilogue< GemmEpilogueTraits_ >, IdentityBlockSwizzle, Index_, ClearAccumulators< GemmConfig_::Accumulators::Element > >
     Ccutlass::gemm::GemmTraits< Helper_::GemmConfig, Helper_::GlobalLoadStreamA, Helper_::GlobalLoadStreamB, Helper_::SharedLoadStreamA, Helper_::SharedLoadStreamB, Helper_::Epilogue, IdentityBlockSwizzle, Index_, Helper_::ClearAccumulators >
     Ccutlass::gemm::GetExtent< kOperand_, Tile_ >
     Ccutlass::gemm::GetExtent< GemmOperand::kA, Tile_ >
     Ccutlass::gemm::GetExtent< GemmOperand::kB, Tile_ >
     Ccutlass::gemm::GlobalLoadStream< Operand, LoadIterator_, StoreIterator_, Transformer_ >
     Ccutlass::gemm::GlobalLoadStreamPair< StreamA_, StreamB_, kResidueInProlog_ >Collect the global load streams for multiplicands
     Ccutlass::platform::greater< T >Std::greater
     Ccutlass::gemm::HgemmSwizzle< GlobalIterator_ >
     Ccutlass::gemm::HgemmTraitsHelper< kLayoutA_, kLayoutB_, OutputTile_, EpilogueFunctor_, ThreadGemmShape_, kScalarsPerLdgA_, kScalarsPerLdgB_, Index_ >
     Ccutlass::gemm::HgemmTransformerA< kLayout_, Iterator_ >
     Ccutlass::gemm::HgemmTransformerA< MatrixLayout::kColumnMajor, Iterator_ >
     Ccutlass::gemm::HgemmTransformerA< MatrixLayout::kRowMajor, Iterator_ >
     Ccutlass::gemm::HgemmTransformerB< kLayout_, Iterator_ >
     Ccutlass::gemm::HgemmTransformerB< MatrixLayout::kColumnMajor, Iterator_ >
     Ccutlass::gemm::HgemmTransformerB< MatrixLayout::kRowMajor, Iterator_ >
     Ccutlass::IdentityDescribes identity elements
     Ccutlass::gemm::IdentityBlockSwizzle
     Ccutlass::IdentityTensorMapFunc< Rank >
     Ccutlass::IdentityTensorMapFunc< Rank_ >
     Ccutlass::gemm::IgemmEpilogueScalar< ScalarD_ >
     Ccutlass::gemm::IgemmEpilogueScalar< int >
     Ccutlass::gemm::IgemmFloatToInt8Converter< kElements_ >
     Ccutlass::gemm::IgemmGlobalLoadTransformer< InputFragment_, OutputScalar_ >
     Ccutlass::gemm::IgemmGlobalLoadTransformer< Fragment< int8_t, kElements_ >, float >
     Ccutlass::gemm::IgemmGlobalStoreTransformer< InputScalar_, OutputFragment_ >
     Ccutlass::gemm::IgemmGlobalStoreTransformer< float, Fragment< int8_t, kElements_ > >
     Ccutlass::gemm::IgemmInt8ToFloatConverter< kElements_ >
     Ccutlass::gemm::IgemmSharedStoreTransformer< InputScalar_, OutputFragment_ >
     Ccutlass::gemm::IgemmSwizzle< GlobalIterator_ >
     Ccutlass::gemm::IgemmTileTraitsHelperA< MatrixLayout::kRowMajor, GemmConfig_, Index_ >
     Ccutlass::gemm::IgemmTileTraitsHelperB< MatrixLayout::kColumnMajor, GemmConfig_, Index_ >
     Ccutlass::gemm::IgemmTraitsHelper< kLayoutA_, kLayoutB_, OutputTile_, ScalarD_, EpilogueFunctor_, ThreadGemmShape_, Index_ >
     Ccutlass::gemm::IgemmTransformerA< kLayout_, Iterator_ >
     Ccutlass::gemm::IgemmTransformerA< MatrixLayout::kColumnMajor, Iterator_ >
     Ccutlass::gemm::IgemmTransformerA< MatrixLayout::kRowMajor, Iterator_ >
     Ccutlass::gemm::IgemmTransformerB< kLayout_, Iterator_ >
     Ccutlass::gemm::IgemmTransformerB< MatrixLayout::kColumnMajor, Iterator_ >
     Ccutlass::gemm::IgemmTransformerB< MatrixLayout::kRowMajor, Iterator_ >
     Ccutlass::int4_t
     Ccutlass::platform::integral_constant< value_t, V >Std::integral_constant
     Ccutlass::platform::integral_constant< bool, V >
     Ccutlass::platform::integral_constant< bool,(is_arithmetic< T >::value||is_void< T >::value||is_same< nullptr_t, remove_cv< T >::type >::value)>
     Ccutlass::platform::integral_constant< bool,(is_base_of_helper< remove_cv< BaseT >::type, remove_cv< DerivedT >::type >::value)||(is_same< remove_cv< BaseT >::type, remove_cv< DerivedT >::type >::value)>
     Ccutlass::platform::integral_constant< bool,(is_fundamental< T >::value||is_pointer< T >::value)>
     Ccutlass::platform::integral_constant< bool,(is_integral< T >::value||is_floating_point< T >::value)>
     Ccutlass::platform::integral_constant< bool,(is_same< float, remove_cv< T >::type >::value||is_same< double, remove_cv< T >::type >::value)>
     Ccutlass::platform::integral_constant< bool,(N &(N - 1))==0 >
     Ccutlass::platform::is_base_of_helper< BaseT, DerivedT >Helper for std::is_base_of
     Ccutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::IteratorAn iterator implementing Predicate Iterator Concept enabling sequential read and write access to predicates
     Ccutlass::IteratorAdvanceSpecifies dimension in which post-increment accesses advance
     Ccutlass::KernelLaunchConfigurationStructure containing the basic launch configuration of a CUDA kernel
     Ccutlass::gemm::Launch< Gemm, WithLaunchBounds >Partial specialization for launching the GEMM kernel with or without launch bounds
     Ccutlass::gemm::Launch< Gemm, false >Partial specialization for launching the GEMM kernel with or without launch bounds
     Ccutlass::platform::less< T >Std::less
     Ccutlass::gemm::LinearScaling< Scalar_, FragmentMultiplyAdd_ >Functor to compute linear combination of fragments
     Ccutlass::Load< Scalar_, kAccessSize, Memory_, kFragmentElementType, FragmentElement_, kStride, size >
     Ccutlass::Load< double, 2, Memory_, FragmentElementType::kScalar, double, kStride, 16 >
     Ccutlass::Load< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, 1, 2 >Partial specialization for 16b loads
     Ccutlass::Load< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 16 >
     Ccutlass::Load< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 4 >
     Ccutlass::Load< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 8 >
     Ccutlass::Load< Scalar_, kAccessSize, Memory_, FragmentElementType::kWmmaMatrix, FragmentElement_, kStride, size >
     Ccutlass::Load< Vector< bin1_t, 32 >, kAccessSize, Memory_, FragmentElementType::kWmmaMatrix, FragmentElement_, kStride, size >
     Ccutlass::Load< Vector< int4_t, 8 >, kAccessSize, Memory_, FragmentElementType::kWmmaMatrix, FragmentElement_, kStride, size >
     Ccutlass::Load< Vector< uint4_t, 8 >, kAccessSize, Memory_, FragmentElementType::kWmmaMatrix, FragmentElement_, kStride, size >
     Ccutlass::log2_down< N, CurrentVal, Count >
     Ccutlass::log2_down< N, 1, Count >
     Ccutlass::log2_up< N, CurrentVal, Count >
     Ccutlass::log2_up< N, 1, Count >
     Ccutlass::gemm::GemmTraits< GemmConfig_, GlobalLoadStreamA_, GlobalLoadStreamB_, SharedLoadStreamA_, SharedLoadStreamB_, Epilogue_, BlockSwizzle_, Index_, ClearAccumulators_ >::MainLoopSharedStorage
     Ccutlass::MatrixLayoutDefines data layouts of various matrix formats usable by TensorRef and other classes
     Ccutlass::MatrixTransformTransformation applied to matrix operands
     Ccutlass::Max< A, B >
     Ccutlass::MemorySpaceEnum to specify which memory space data resides in
     Ccutlass::Min< A, B >
     Ccutlass::platform::nullptr_tStd::nullptr_t
     Ccutlass::platform::alignment_of< value_t >::pad
     Ccutlass::gemm::LinearScalingDevicePtr< Scalar_, FragmentMultiplyAdd_ >::ParamsThe parameters
     Ccutlass::gemm::GlobalLoadStream< Operand, LoadIterator_, StoreIterator_, Transformer_ >::ParamsThe params
     Ccutlass::gemm::SharedStreamPair< StreamA_, StreamB_ >::ParamsParameters object passed to load iterators
     Ccutlass::ZipTileIterator< First_, Second_ >::ParamsParams object
     Ccutlass::gemm::LinearScaling< Scalar_, FragmentMultiplyAdd_ >::ParamsThe parameters
     Ccutlass::gemm::GlobalLoadStreamPair< StreamA_, StreamB_, kResidueInProlog_ >::ParamsParameters object
     Ccutlass::gemm::GemmGlobalIteratorCd< TileTraits_, Index_ >::ParamsThe params
     Ccutlass::gemm::GemmEpilogueTraits< OutputTile_, Accumulators_, GlobalLoadIteratorC_, GlobalTransformerC_, GlobalTransformerD_, GlobalStoreIteratorD_, SharedStoreIteratorD_, SharedStoreTransformerD_, SharedLoadStreamD_, Iterations_, Delta_, Functor_, Index_ >::ParamsThe params
     Ccutlass::TileIteratorBase< Traits_, Scalar_, Advance_, MemorySpace, Index_, FragmentElement_, FragmentElementType_, Skew_ >::ParamsParameters to the iterator
     Ccutlass::TileLoadStream< Iterator_, Transformer_ >::ParamsParameters object used to construct generic load stream
     Ccutlass::TileStoreStream< Iterator_, Transformer_ >::ParamsParameters used to construct the stream
     Ccutlass::gemm::SharedLoadStream< Iterator_, Transformer_ >::ParamsThe params
     Ccutlass::platform::plus< T >Platform::plus
     Ccutlass::PredicateTileAdapter< PredicateVector_, Iterations_ >Adapter to enable random access to predicates via logical coordinate within a tile
     Ccutlass::TileLoadStream< Iterator_, Transformer_ >::PredicateVectorEmpty predicate vector struct
     Ccutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >Statically sized array of bits implementing the Predicate Vector Concept
     Ccutlass::TileStoreStream< Iterator_, Transformer_ >::PredicateVectorEmpty predicate vector struct
     Ccutlass::PredicateVector< Base::Iterations::kW >
     Ccutlass::PredicateVector< ShapeCount< typename Base::Iterations >::kCount >
     Ccutlass::gemm::ProjectOperand< operand, Kstrided >
     Ccutlass::gemm::ProjectOperand< GemmOperand::kA, Kstrided >Project A operand - (0, K, M)
     Ccutlass::gemm::ProjectOperand< GemmOperand::kB, Kstrided >Project B operand - (0, K, N)
     Ccutlass::gemm::ProjectOperand< GemmOperand::kC, true >Project C operand - (0, N, M)
     Ccutlass::gemm::ProjectOperand< GemmOperand::kD, true >Project D operand - (0, N, M)
     Ccutlass::RegularTilePredicateFunctor< Delta_ >Functor computing a predicate given the logical position of an access
     Ccutlass::platform::remove_const< T >Std::remove_const (non-const specialization)
     Ccutlass::platform::remove_const< const T >Std::remove_const (const specialization)
     Ccutlass::platform::remove_cv< T >Std::remove_cv
     Ccutlass::platform::remove_volatile< T >Std::remove_volatile (non-volatile specialization)
     Ccutlass::platform::remove_volatile< volatile T >Std::remove_volatile (volatile specialization)
     Ccutlass::gemm::ReshapeThreads< Tile_, Threads_, bool >
     Ccutlass::gemm::ReshapeThreads< Tile_, Threads_, true >
     Ccutlass::ReshapeTile< Tile_, kAccessSize_, bool >
     Ccutlass::ReshapeTile< Tile_, kAccessSize_, true >
     Ccutlass::MatrixLayout::RowMajorMapping function for row-major matrices
     Ccutlass::MatrixLayout::RowMajorBlockLinear< BlockRows, BlockColumns >
     Ccutlass::gemm::RowMajorBlockSwizzle< groupRows, swDirection >
     Ccutlass::MatrixLayout::RowMajorInterleaved< Interleave >
     Ccutlass::ScalarIO< T >Helper to enable formatted printing of CUTLASS scalar types to an ostream
     Ccutlass::detail::ScalarOrPointer< Scalar_ >
     Ccutlass::detail::ScalarOrPointer< Scalar >
     Ccutlass::Shape< kD_, kH_, kW_, kC_ >A Shape implementing Layout Concept describing the dimensions of a cube
     Ccutlass::ShapeAdd< A_, B_ >
     Ccutlass::ShapeCount< Shape >Compute derived counted of a Layout Concept based class
     Ccutlass::ShapeDiv< A_, B_ >
     Ccutlass::ShapeDivCeiling< A_, B_ >
     Ccutlass::ShapeMax< A_, B_ >
     Ccutlass::ShapeMin< A_, B_ >
     Ccutlass::ShapeMul< A_, B_ >
     Ccutlass::ShapeScale< A_, kScale_ >
     Ccutlass::ShapeStrides< Shape_, elementsPerAccess >
     Ccutlass::ShapeSub< A_, B_ >
     Ccutlass::gemm::SharedLoadStream< Iterator_, Transformer_ >
     Ccutlass::gemm::GemmEpilogueTraits< OutputTile_, Accumulators_, GlobalLoadIteratorC_, GlobalTransformerC_, GlobalTransformerD_, GlobalStoreIteratorD_, SharedStoreIteratorD_, SharedStoreTransformerD_, SharedLoadStreamD_, Iterations_, Delta_, Functor_, Index_ >::SharedStorageThe shared memory to swizzle the data in the epilogue
     Ccutlass::gemm::GlobalLoadStreamPair< StreamA_, StreamB_, kResidueInProlog_ >::SharedStorageDefines a structure containing shared storage for each pair
     Ccutlass::gemm::GemmTraits< GemmConfig_, GlobalLoadStreamA_, GlobalLoadStreamB_, SharedLoadStreamA_, SharedLoadStreamB_, Epilogue_, BlockSwizzle_, Index_, ClearAccumulators_ >::SharedStorageThe storage in shared memory
     Ccutlass::gemm::GlobalLoadStream< Operand, LoadIterator_, StoreIterator_, Transformer_ >::SharedStorage
     Ccutlass::gemm::ClearAccumulators< Scalar_, kLanes_ >::SharedStorageThe shared storage
     Ccutlass::gemm::SharedStreamPair< StreamA_, StreamB_ >Collect the global load streams for multiplicands
     Ccutlass::gemm::SimplifiedGemmTraitsHelper< GemmTileTraitsHelperA_, GemmTileTraitsHelperB_, Index_ >
     Ccutlass::sqrt_est< N >
     Ccutlass::StorageType< alignment >
     Ccutlass::StorageType< 1 >
     Ccutlass::StorageType< 2 >
     Ccutlass::StorageType< 4 >
     Ccutlass::StorageType< kAlignment_ >
     Ccutlass::StorageType< sizeof(Scalar)>
     Ccutlass::Store< Scalar_, kAccessSize, Memory_, kFragmentElementType, FragmentElement_, kStride, size >
     Ccutlass::Store< double, 2, Memory_, FragmentElementType::kScalar, double, kStride, 16 >
     Ccutlass::Store< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, 1, 2 >
     Ccutlass::Store< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 16 >
     Ccutlass::Store< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 4 >
     Ccutlass::Store< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 8 >
     Ccutlass::Store< Scalar_, kAccessSize, Memory_, FragmentElementType::kWmmaMatrix, FragmentElement_, kStride, size >
     Ccutlass::gemm::GemmEpilogueTraits< OutputTile_, Accumulators_, GlobalLoadIteratorC_, GlobalTransformerC_, GlobalTransformerD_, GlobalStoreIteratorD_, SharedStoreIteratorD_, SharedStoreTransformerD_, SharedLoadStreamD_, Iterations_, Delta_, Functor_, Index_ >::StreamSharedStorageThe shared memory storage to exchange data
     Ccutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >::StrideVector
     Ccutlass::gemm::swizzleDirection
     Ccutlass::TensorRef< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >
     Ccutlass::TensorRef< AType const, 2 >
     Ccutlass::TensorRef< BType const, 2 >
     Ccutlass::TensorRef< CType const, 2 >
     Ccutlass::TensorRef< DType, 2 >
     Ccutlass::TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >Specialization for rank=1 case with no internal StrideVector
     Ccutlass::TensorRefArray< Storage_, Rank_, MapFunc_, StorageRank_, Index_, LongIndex_ >
     Ccutlass::gemm::ThreadMultiplyAdd< ThreadGemmShape_, ThreadsPerWarp_, ScalarA_, ScalarB_, ScalarC_, kLayout_ >Template performing matrix multiply-add operation within a thread
     Ccutlass::gemm::ThreadMultiplyAdd< ThreadGemmShape_, ThreadsPerWarp_, half, half, float >Template performing matrix multiply-add operation within a thread
     Ccutlass::gemm::ThreadMultiplyAdd< ThreadGemmShape_, ThreadsPerWarp_, half, half, half >Template performing matrix multiply-add operation within a thread
     Ccutlass::gemm::ThreadMultiplyAdd< ThreadGemmShape_, ThreadsPerWarp_, int8_t, int8_t, int >Template performing matrix multiply-add operation within a thread
     Ccutlass::gemm::GemmSharedStoreTileAbTraits< Scalar_, Tile_, Threads_, kScalarsPerSts_ >::ThreadOffset
     Ccutlass::gemm::WmmaGemmGlobalIteratorCdTraits< Scalar_, Tile_, Threads_, kAccessSize_ >::ThreadOffsetComputes the thread offset in (H, W) based on thread ID
     Ccutlass::gemm::GemmGlobalTileCdTraits< Scalar_, Tile_, Threads_, kStrideH_, kAccessSize_ >::ThreadOffsetComputes the thread offset in (H, W) based on thread ID
     Ccutlass::gemm::GemmSharedLoadTileATraits< Scalar_, OutputTile_, Warps_, ThreadsPerWarp_, InstructionShape_, kStages_, kScalarsPerLds_, kSkew_ >::ThreadOffsetComputes the thread offset in (H, W) based on thread ID
     Ccutlass::gemm::GemmSharedStoreWithSkewTileAbTraits< Scalar_, Tile_, Threads_, kScalarsPerSts_, kSkew_ >::ThreadOffset
     Ccutlass::gemm::IgemmGlobalTileTraits< kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_ >::ThreadOffsetComputes the thread offset in (H, W) based on thread ID
     Ccutlass::gemm::GemmSharedLoadTileBTraits< Scalar_, OutputTile_, Warps_, ThreadsPerWarp_, InstructionShape_, kStages_, kScalarsPerLds_, kSkew_ >::ThreadOffsetComputes the thread offset in (H, W) based on thread ID
     Ccutlass::gemm::GemmGlobalTileTraits< kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_ >::ThreadOffsetComputes the thread offset in (H, W) based on thread ID
     Ccutlass::gemm::GemmSharedLoadTileDTraits< Scalar_, OutputTile_, Warps_, ThreadsPerWarp_, kTileH_, kScalarsPerLds_, kSkew_ >::ThreadOffsetComputes the thread offset in (H, W) based on thread ID
     Ccutlass::TileTraitsWarpRake< Tile_, Threads >::ThreadOffsetComputes the thread offset in (H, W) based on thread ID
     Ccutlass::gemm::GemmSharedStoreTileDTraits< Scalar_, OutputTile_, Warps_, ThreadsPerWarp_, kScalarsPerSts_, kSkew_ >::ThreadOffsetComputes the thread offset in (H, W) based on thread ID
     Ccutlass::gemm::HgemmCrosswiseGlobalTileTraits< kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_ >::ThreadOffsetComputes the thread offset in (H, W) based on thread ID
     Ccutlass::TileAllocation< Scalar_, Shape_ >Class for storing a tile in memory and accessing it through a tensor ref
     Ccutlass::TiledThreadOffset< ThreadShape >Basic thread offset function computed from a thread shape
     Ccutlass::TileIteratorBase< Traits_, Scalar_, Advance_, MemorySpace, Index_, FragmentElement_, FragmentElementType_, Skew_ >Iterator for accessing a stripmined tile in memory
     Ccutlass::TileIteratorBase< TileTraits_, TileTraits_::Scalar, Advance_, MemorySpace, Index_, TileTraits_::Scalar, FragmentElementType::kScalar, Shape< 0, 0, 0, 0 > >
     Ccutlass::TileIteratorBase< TileTraits_, TileTraits_::Scalar, IteratorAdvance::kH, MemorySpace::kGlobal, Index_ >
     Ccutlass::TileLoadStream< Iterator_, Transformer_ >Generic stream for loading and transforming fragments
     Ccutlass::TileStoreStream< Iterator_, Transformer_ >Generic stream for transforming and storing fragments
     Ccutlass::TileTraits< Tile_, Delta_, Iterations_, ThreadOffset_, AccessSize >A template defining Tile Traits Concept
     Ccutlass::TileTraitsContiguousMajor< Tile_, Threads >
     Ccutlass::TileTraitsStandard< Tile_, Threads >Chooses 'best' shape to enable warp raking along contiguous dimension if possible
     Ccutlass::TileTraitsStrideMajor< Tile_, Threads >
     Ccutlass::TileTraitsWarpRake< Tile_, Threads >Tiling in which warps rake across the contiguous dimension
     Ccutlass::PredicateVector< kPredicates_, kPredicatesPerByte_, kPredicateStart_ >::TrivialIteratorIterator that always returns true
     Ccutlass::TrivialPredicateTileAdapterAlways returns true predicate
     Ccutlass::uint4_t
     Ccutlass::platform::unique_ptr< T, Deleter >Std::unique_ptr
     Ccutlass::Vector< Scalar_, kLanes_ >
     Ccutlass::Vector< bin1_t, kLanes_ >Vector definition for 1-bit binary datatype
     Ccutlass::Vector< half, 1 >
     Ccutlass::Vector< half, kLanes_ >
     Ccutlass::Vector< int4_t, kLanes_ >Vector definition for 4-bit signed integer datatype
     Ccutlass::Vector< uint4_t, kLanes_ >Vector definition for 4-bit unsigned integer datatype
     Ccutlass::Vectorize< Element_, kLanes_ >
     Ccutlass::Vectorize< Vector< bin1_t, 32 >, kLanes_ >
     Ccutlass::Vectorize< Vector< int4_t, 8 >, kLanes_ >
     Ccutlass::Vectorize< Vector< uint4_t, 8 >, kLanes_ >
     Ccutlass::VectorTraits< T >Traits describing properties of vectors and scalar-as-vectors
     Ccutlass::VectorTraits< Vector< T, Lanes > >Partial specialization for actual cutlass::Vector
     Ccutlass::VectorTraits< Vector< T, Lanes > const >Partial specialization for actual cutlass::Vector
     Ccutlass::ZipConvert< First_, Second_ >Zips two convert operations
     Ccutlass::ZipFragment< First_, Second_ >A template defining Fragment Concept
     Ccutlass::ZipTensorRef< First_, Second_ >
     Ccutlass::ZipTileAllocation< First_, Second_ >Manages a pair of tile allocations as if they are one allocation
     Ccutlass::ZipTileIterator< First_, Second_ >Constructs an iterator from a pair of iterators
    diff --git a/docs/igemm__epilogue_8h.html b/docs/igemm__epilogue_8h.html index 9b5e5ccf0..f7332de21 100644 --- a/docs/igemm__epilogue_8h.html +++ b/docs/igemm__epilogue_8h.html @@ -82,13 +82,13 @@ $(function() {

    Defines the epilogue phase of the GEMM computation for IGEMM, supporting integer and floating-point output matrix formats. More...

    -
    #include <cutlass/convert.h>
    -#include <cutlass/fragment.h>
    -#include <cutlass/gemm/gemm_global_stream.h>
    -#include <cutlass/gemm/gemm_shared_stream.h>
    -#include <cutlass/gemm/igemm_global_tile.h>
    -#include <cutlass/reshape_tile.h>
    -#include <cutlass/tile_iterator.h>
    +

    Go to the source code of this file.

    @@ -127,7 +127,7 @@ Namespaces diff --git a/docs/igemm__epilogue_8h_source.html b/docs/igemm__epilogue_8h_source.html index bfef820ae..43f9f1583 100644 --- a/docs/igemm__epilogue_8h_source.html +++ b/docs/igemm__epilogue_8h_source.html @@ -76,67 +76,66 @@ $(function() {
    igemm_epilogue.h
    -Go to the documentation of this file.
    1 /***************************************************************************************************
    2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
    3  *
    4  * Redistribution and use in source and binary forms, with or without modification, are permitted
    5  * provided that the following conditions are met:
    6  * * Redistributions of source code must retain the above copyright notice, this list of
    7  * conditions and the following disclaimer.
    8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
    9  * conditions and the following disclaimer in the documentation and/or other materials
    10  * provided with the distribution.
    11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
    12  * to endorse or promote products derived from this software without specific prior written
    13  * permission.
    14  *
    15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
    16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
    17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
    18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
    19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
    20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
     21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    23  *
    24  **************************************************************************************************/
    29 #pragma once
    30 
    31 #include <cutlass/convert.h>
    32 #include <cutlass/fragment.h>
    36 #include <cutlass/reshape_tile.h>
    37 #include <cutlass/tile_iterator.h>
    38 
    39 namespace cutlass {
    40 namespace gemm {
    41 
    43 
    44 template <int kElements_>
    50 
    51  // We are packing 4 floats into int32 registers so we need kElements to be multiple of 4.
    52  static_assert(kElements_ % 4 == 0, "kElements must be multiple of 4");
    53 
    55  CUTLASS_DEVICE IgemmFloatToInt8Converter() {}
    56 
    58  CUTLASS_DEVICE void transform(InputFragment const& src, OutputFragment& dst) {
    59  transform(src, 0, dst);
    60  }
    61 
    63  template <typename Fragment_>
    64  CUTLASS_DEVICE void transform(Fragment_ const& src, int offset, OutputFragment& dst) {
    65  // The inputs.
    66  float4 const* src_f4 = reinterpret_cast<float4 const*>(&src[0]);
    67  // The outputs.
    68  int* dst_int = reinterpret_cast<int*>(&dst[0]);
    69 
    70  // Iterate over the floats and pack them together to produce ints.
    71  for (int i = 0; i < kElements_ / 4; ++i) {
    72  // Read the float4.
    73  float4 f4 = src_f4[i];
    74 
    75  // Clamp the 4 elements of the floats to the [-128, +127] range.
    76  float x = fmaxf(-128.f, fminf(127.f, f4.x));
    77  float y = fmaxf(-128.f, fminf(127.f, f4.y));
    78  float z = fmaxf(-128.f, fminf(127.f, f4.z));
    79  float w = fmaxf(-128.f, fminf(127.f, f4.w));
    80 
    81  // Convert to integers.
    82  int ix = (int)x;
    83  int iy = (int)y;
    84  int iz = (int)z;
    85  int iw = (int)w;
    86 
    87  // Extract the lower bytes to build an int32 with 4 int8.
    88  asm volatile("prmt.b32 %0, %0, %1, 0x1140;" : "+r"(ix) : "r"(iy));
    89  asm volatile("prmt.b32 %0, %0, %1, 0x1140;" : "+r"(iz) : "r"(iw));
    90  asm volatile("prmt.b32 %0, %0, %1, 0x5410;" : "+r"(ix) : "r"(iz));
    91 
    92  // Store the int.
    93  dst_int[i] = ix;
    94  }
    95  }
    96 };
    97 
    99 
    100 template <typename InputScalar_, typename OutputFragment_>
    103 };
    104 
    105 template <int kElements_>
    106 struct IgemmGlobalStoreTransformer<float, Fragment<int8_t, kElements_> > {
    108 };
    109 
    111 
    112 template <int kElements_>
    118 
    119  // We are unpacking 4 int8s from int32.
    120  static_assert(kElements_ % 4 == 0, "kElements must be multiple of 4");
    121 
    123  CUTLASS_DEVICE IgemmInt8ToFloatConverter() {}
    124 
    126  CUTLASS_DEVICE void transform(InputFragment const& src, OutputFragment& dst) {
    127  transform(src, 0, dst);
    128  }
    129 
    131  template <typename Fragment_>
    132  CUTLASS_DEVICE void transform(Fragment_ const& src, int offset, OutputFragment& dst) {
    133  // The inputs.
    134  int const* src_int = reinterpret_cast<int const*>(&src[0]);
    135  // The outputs.
    136  float4* dst_f4 = reinterpret_cast<float4*>(&dst[0]);
    137 
    138  // Iterate over the int8 and unpack them together to produce floats.
    139  for (int i = 0; i < kElements_ / 4; ++i) {
    140  // Read the int.
    141  int ix, iy, iz, iw = src_int[i];
    142 
    143  // Extract the 4 bytes.
    144  asm volatile("prmt.b32 %0, 0x0, %1, 0x4440;" : "=r"(ix) : "r"(iw));
    145  asm volatile("prmt.b32 %0, 0x0, %1, 0x4441;" : "=r"(iy) : "r"(iw));
    146  asm volatile("prmt.b32 %0, 0x0, %1, 0x4442;" : "=r"(iz) : "r"(iw));
    147  asm volatile("prmt.b32 %0, 0x0, %1, 0x4443;" : "=r"(iw) : "r"(iw));
    148 
    149  // The floats.
    150  float fx, fy, fz, fw;
    151 
    152  // Convert to floats (make sure we generate I2F.F32.S8).
    153  asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fx) : "r"(ix));
    154  asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fy) : "r"(iy));
    155  asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fz) : "r"(iz));
    156  asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fw) : "r"(iw));
    157 
    158  // Store the float4.
    159  dst_f4[i] = make_float4(fx, fy, fz, fw);
    160  }
    161  }
    162 };
    163 
    165 
    166 template <typename InputFragment_, typename OutputScalar_>
    169 };
    170 
    171 template <int kElements_>
    172 struct IgemmGlobalLoadTransformer<Fragment<int8_t, kElements_>, float> {
    174 };
    175 
    177 
    178 template <typename InputScalar_, typename OutputFragment_>
    181 };
    182 
    184 
    185 template <typename IgemmConfig_, typename EpilogueFunctor_, typename Index_>
    187  : public GemmEpilogueTraitsHelper<IgemmConfig_, EpilogueFunctor_, Index_> {
    191  typedef IgemmConfig_ IgemmConfig;
    192 
    194  typedef typename Base::Scalar Scalar;
    196  typedef typename Base::Iterations Iterations;
    198  typedef typename Base::Delta Delta;
    199 
    207  typedef
    209 
    217  typedef
    219 
    232  SharedStoreFragmentD>::Transformer
    242 };
    243 
    245 
    246 template <
    248  typename IgemmConfig_,
    250  typename EpilogueFunctor_,
    252  typename Index_ = int,
    256  // The output tile.
    257  typename IgemmConfig_::OutputTile,
    258  // The accumulators.
    259  typename IgemmConfig_::Accumulators,
    260  // The global iterator for C.
    261  typename Helper_::GlobalLoadIteratorC,
    262  // The transformer for C.
    263  typename Helper_::GlobalTransformerC,
    264  // The transformer for D.
    265  typename Helper_::GlobalTransformerD,
    266  // The global iterator for D.
    267  typename Helper_::GlobalStoreIteratorD,
    268  // The iterator to store D to shared memory.
    269  typename Helper_::SharedStoreIteratorD,
    270  // The shared store transformer for D.
    271  typename Helper_::SharedStoreTransformerD,
    272  // The iterator to load D from shared memory.
    273  typename Helper_::SharedLoadIteratorD,
    274  // The iterations.
    275  typename Helper_::Iterations,
    276  // The strides between iterations.
    277  typename Helper_::Delta,
    278  // The functor to be used in the epilogue.
    279  EpilogueFunctor_,
    280  // The index.
    281  Index_> {
    283  static bool const kInt8Output =
    285 };
    286 
    288 
    289 template <typename GemmEpilogueTraits_, bool = GemmEpilogueTraits_::kInt8Output>
    290 struct IgemmEpilogue : public GemmEpilogue<GemmEpilogueTraits_> {
    293 
    295  CUTLASS_DEVICE IgemmEpilogue(typename Base::Params const& params_,
    296  typename Base::SharedStorage& shared_storage_,
    297  typename Base::Index m_,
    298  typename Base::Index n_)
    299  : Base(params_, shared_storage_, m_, n_) {}
    300 };
    301 
    303 
    304 template <typename GemmEpilogueTraits_>
    305 struct IgemmEpilogue<GemmEpilogueTraits_, true> : public GemmEpilogue<GemmEpilogueTraits_> {
    308 
    310  CUTLASS_DEVICE IgemmEpilogue(typename Base::Params const& params_,
    311  typename Base::SharedStorage& shared_storage_,
    312  typename Base::Index m_,
    313  typename Base::Index n_)
    314  : Base(params_, shared_storage_, m_, n_) {}
    315 };
    316 
    318 
    319 } // namespace gemm
    320 } // namespace cutlass
    Definition: gemm_global_tile.h:116
    +Go to the documentation of this file.
    1 /***************************************************************************************************
    2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
    3  *
    4  * Redistribution and use in source and binary forms, with or without modification, are permitted
    5  * provided that the following conditions are met:
    6  * * Redistributions of source code must retain the above copyright notice, this list of
    7  * conditions and the following disclaimer.
    8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
    9  * conditions and the following disclaimer in the documentation and/or other materials
    10  * provided with the distribution.
    11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
    12  * to endorse or promote products derived from this software without specific prior written
    13  * permission.
    14  *
    15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
    16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
    17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
    18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
    19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
    20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
    21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    23  *
    24  **************************************************************************************************/
    29 #pragma once
    30 
    31 #include "cutlass/convert.h"
    32 #include "cutlass/fragment.h"
    36 #include "cutlass/reshape_tile.h"
    37 #include "cutlass/tile_iterator.h"
    38 
    39 namespace cutlass {
    40 namespace gemm {
    41 
    43 
    44 template <int kElements_>
    50 
    51  // We are packing 4 floats into int32 registers so we need kElements to be multiple of 4.
    52  static_assert(kElements_ % 4 == 0, "kElements must be multiple of 4");
    53 
    55  CUTLASS_DEVICE IgemmFloatToInt8Converter() {}
    56 
    58  CUTLASS_DEVICE void transform(InputFragment const& src, OutputFragment& dst) {
    59  transform(src, 0, dst);
    60  }
    61 
    63  template <typename Fragment_>
    64  CUTLASS_DEVICE void transform(Fragment_ const& src, int offset, OutputFragment& dst) {
    65  // The inputs.
    66  float4 const* src_f4 = reinterpret_cast<float4 const*>(&src[0]);
    67  // The outputs.
    68  int* dst_int = reinterpret_cast<int*>(&dst[0]);
    69 
    70  // Iterate over the floats and pack them together to produce ints.
    71  for (int i = 0; i < kElements_ / 4; ++i) {
    72  // Read the float4.
    73  float4 f4 = src_f4[i];
    74 
    75  // Clamp the 4 elements of the floats to the [-128, +127] range.
    76  float x = fmaxf(-128.f, fminf(127.f, f4.x));
    77  float y = fmaxf(-128.f, fminf(127.f, f4.y));
    78  float z = fmaxf(-128.f, fminf(127.f, f4.z));
    79  float w = fmaxf(-128.f, fminf(127.f, f4.w));
    80 
    81  // Convert to integers.
    82  int ix = (int)x;
    83  int iy = (int)y;
    84  int iz = (int)z;
    85  int iw = (int)w;
    86 
    87  // Extract the lower bytes to build an int32 with 4 int8.
    88  asm volatile("prmt.b32 %0, %0, %1, 0x1140;" : "+r"(ix) : "r"(iy));
    89  asm volatile("prmt.b32 %0, %0, %1, 0x1140;" : "+r"(iz) : "r"(iw));
    90  asm volatile("prmt.b32 %0, %0, %1, 0x5410;" : "+r"(ix) : "r"(iz));
    91 
    92  // Store the int.
    93  dst_int[i] = ix;
    94  }
    95  }
    96 };
    97 
    99 
    100 template <typename InputScalar_, typename OutputFragment_>
    103 };
    104 
    105 template <int kElements_>
    106 struct IgemmGlobalStoreTransformer<float, Fragment<int8_t, kElements_> > {
    108 };
    109 
    111 
    112 template <int kElements_>
    118 
    119  // We are unpacking 4 int8s from int32.
    120  static_assert(kElements_ % 4 == 0, "kElements must be multiple of 4");
    121 
    123  CUTLASS_DEVICE IgemmInt8ToFloatConverter() {}
    124 
    126  CUTLASS_DEVICE void transform(InputFragment const& src, OutputFragment& dst) {
    127  transform(src, 0, dst);
    128  }
    129 
    131  template <typename Fragment_>
    132  CUTLASS_DEVICE void transform(Fragment_ const& src, int offset, OutputFragment& dst) {
    133  // The inputs.
    134  int const* src_int = reinterpret_cast<int const*>(&src[0]);
    135  // The outputs.
    136  float4* dst_f4 = reinterpret_cast<float4*>(&dst[0]);
    137 
    138  // Iterate over the int8 and unpack them together to produce floats.
    139  for (int i = 0; i < kElements_ / 4; ++i) {
    140  // Read the int.
    141  int ix, iy, iz, iw = src_int[i];
    142 
    143  // Extract the 4 bytes.
    144  asm volatile("prmt.b32 %0, 0x0, %1, 0x4440;" : "=r"(ix) : "r"(iw));
    145  asm volatile("prmt.b32 %0, 0x0, %1, 0x4441;" : "=r"(iy) : "r"(iw));
    146  asm volatile("prmt.b32 %0, 0x0, %1, 0x4442;" : "=r"(iz) : "r"(iw));
    147  asm volatile("prmt.b32 %0, 0x0, %1, 0x4443;" : "=r"(iw) : "r"(iw));
    148 
    149  // The floats.
    150  float fx, fy, fz, fw;
    151 
    152  // Convert to floats (make sure we generate I2F.F32.S8).
    153  asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fx) : "r"(ix));
    154  asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fy) : "r"(iy));
    155  asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fz) : "r"(iz));
    156  asm volatile("cvt.rn.f32.s8 %0, %1;" : "=f"(fw) : "r"(iw));
    157 
    158  // Store the float4.
    159  dst_f4[i] = make_float4(fx, fy, fz, fw);
    160  }
    161  }
    162 };
    163 
    165 
    166 template <typename InputFragment_, typename OutputScalar_>
    169 };
    170 
    171 template <int kElements_>
    172 struct IgemmGlobalLoadTransformer<Fragment<int8_t, kElements_>, float> {
    174 };
    175 
    177 
    178 template <typename InputScalar_, typename OutputFragment_>
    181 };
    182 
    184 
    185 template <typename IgemmConfig_, typename EpilogueFunctor_, typename Index_>
    187  : public GemmEpilogueTraitsHelper<IgemmConfig_, EpilogueFunctor_, Index_> {
    191  typedef IgemmConfig_ IgemmConfig;
    192 
    194  typedef typename Base::Scalar Scalar;
    196  typedef typename Base::Iterations Iterations;
    198  typedef typename Base::Delta Delta;
    199 
    207  typedef
    209 
    217  typedef
    219 
    232  SharedStoreFragmentD>::Transformer
    242 };
    243 
    245 
    246 template <
    248  typename IgemmConfig_,
    250  typename EpilogueFunctor_,
    252  typename Index_ = int,
    256  // The output tile.
    257  typename IgemmConfig_::OutputTile,
    258  // The accumulators.
    259  typename IgemmConfig_::Accumulators,
    260  // The global iterator for C.
    261  typename Helper_::GlobalLoadIteratorC,
    262  // The transformer for C.
    263  typename Helper_::GlobalTransformerC,
    264  // The transformer for D.
    265  typename Helper_::GlobalTransformerD,
    266  // The global iterator for D.
    267  typename Helper_::GlobalStoreIteratorD,
    268  // The iterator to store D to shared memory.
    269  typename Helper_::SharedStoreIteratorD,
    270  // The shared store transformer for D.
    271  typename Helper_::SharedStoreTransformerD,
    272  // The stream to load D from shared memory.
    273  typename Helper_::SharedLoadStreamD,
    274  // The iterations.
    275  typename Helper_::Iterations,
    276  // The strides between iterations.
    277  typename Helper_::Delta,
    278  // The functor to be used in the epilogue.
    279  EpilogueFunctor_,
    280  // The index.
    281  Index_> {
    283  static bool const kInt8Output =
    285 };
    286 
    288 
    289 template <typename GemmEpilogueTraits_, bool = GemmEpilogueTraits_::kInt8Output>
    290 struct IgemmEpilogue : public GemmEpilogue<GemmEpilogueTraits_> {
    293 
    295  CUTLASS_DEVICE IgemmEpilogue(typename Base::Params const& params_,
    296  typename Base::SharedStorage& shared_storage_,
    297  Coord<3> const& _problem_size)
    298  : Base(params_, shared_storage_, _problem_size) {}
    299 };
    300 
    302 
    303 template <typename GemmEpilogueTraits_>
    304 struct IgemmEpilogue<GemmEpilogueTraits_, true> : public GemmEpilogue<GemmEpilogueTraits_> {
    307 
    309  CUTLASS_DEVICE IgemmEpilogue(typename Base::Params const& params_,
    310  typename Base::SharedStorage& shared_storage_,
    311  Coord<3> const& _problem_size)
    312  : Base(params_, shared_storage_, _problem_size) {}
    313 };
    314 
    316 
    317 } // namespace gemm
    318 } // namespace cutlass
    Definition: gemm_global_tile.h:120
    Definition: igemm_epilogue.h:255
    -
    Definition: load_store.h:42
    +
    Definition: load_store.h:41
    Base::Delta Delta
    The iterations strides.
    Definition: igemm_epilogue.h:198
    -
    Base::Fragment Fragment
    Fragment definition.
    Definition: tile_iterator.h:682
    Base::SharedStoreTileTraits SharedStoreTileTraits
    The traits class for the shared iterator to store D to shared memory.
    Definition: igemm_epilogue.h:221
    IgemmGlobalStoreTransformer< Scalar, GlobalFragmentD >::Transformer GlobalTransformerD
    The transformer from accumulators to shared memory fragments.
    Definition: igemm_epilogue.h:218
    Definition: convert.h:33
    Base::SharedLoadTileTraits SharedLoadTileTraits
    The traits class for the shared iterator to load D from shared memory.
    Definition: igemm_epilogue.h:235
    TileLoadIterator< SharedLoadTileTraits, typename SharedLoadTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedLoadIteratorD
    The shared iterator to load D from shared memory.
    Definition: igemm_epilogue.h:241
    -
    Definition: gemm_epilogue_traits.h:171
    +
    Definition: gemm_epilogue_traits.h:186
    GemmEpilogue< GemmEpilogueTraits_ > Base
    The base class.
    Definition: igemm_epilogue.h:292
    -
    Traits::Params Params
    The params.
    Definition: gemm_epilogue.h:57
    -
    Definition: gemm_epilogue.h:53
    +
    Traits::Params Params
    The params.
    Definition: gemm_epilogue.h:46
    +
    Definition: gemm_epilogue.h:42
    Definition: igemm_epilogue.h:167
    -
    std::is_same (false specialization)
    Definition: platform.h:412
    +
    std::is_same (false specialization)
    Definition: platform.h:420
    Defines the Tile Traits concept and iterators for loading and storing to tiles efficiently.
    CUTLASS_DEVICE IgemmInt8ToFloatConverter()
    Ctor.
    Definition: igemm_epilogue.h:123
    SharedStoreIteratorD::Fragment SharedStoreFragmentD
    The fragment that needs to be passed to that store iterator.
    Definition: igemm_epilogue.h:229
    -
    EpilogueFunctor_::Scalar Scalar
    The scalar.
    Definition: gemm_epilogue_traits.h:173
    +
    EpilogueFunctor_::Scalar Scalar
    The scalar.
    Definition: gemm_epilogue_traits.h:188
    Definition: igemm_epilogue.h:186
    -
    Definition: load_store.h:43
    +
    Definition: load_store.h:42
    Fragment< int8_t, kElements_ > InputFragment
    The input fragment.
    Definition: igemm_epilogue.h:115
    +
    Fragment< FragmentElement, ShapeCount< Iterations >::kCount *kAccessSize > Fragment
    The fragment.
    Definition: tile_iterator.h:196
    Definition: igemm_epilogue.h:290
    Definition: igemm_epilogue.h:45
    CUTLASS_DEVICE void transform(Fragment_ const &src, int offset, OutputFragment &dst)
    Transform a fragment.
    Definition: igemm_epilogue.h:64
    -
    Traits::SharedStorage SharedStorage
    The shared storage.
    Definition: gemm_epilogue.h:59
    +
    Traits::SharedStorage SharedStorage
    The shared storage.
    Definition: gemm_epilogue.h:48
    A template defining Fragment Concept.
    Definition: fragment.h:99
    -
    Definition: tile_iterator.h:62
    +
    Definition: tile_iterator.h:65
    CUTLASS_DEVICE void transform(InputFragment const &src, OutputFragment &dst)
    Transform a fragment.
    Definition: igemm_epilogue.h:126
    Base::Scalar Scalar
    The scalar type of the epilogue.
    Definition: igemm_epilogue.h:194
    +
    CUTLASS_DEVICE IgemmEpilogue(typename Base::Params const &params_, typename Base::SharedStorage &shared_storage_, Coord< 3 > const &_problem_size)
    Ctor.
    Definition: igemm_epilogue.h:295
    GlobalLoadIteratorC::Fragment GlobalFragmentC
    The fragment that needs to be produced by the load iterator.
    Definition: igemm_epilogue.h:205
    +
    Base::Fragment Fragment
    Fragment definition.
    Definition: tile_iterator.h:901
    CUTLASS_DEVICE void transform(InputFragment const &src, OutputFragment &dst)
    Transform a fragment.
    Definition: igemm_epilogue.h:58
    Fragment< int8_t, kElements_ > OutputFragment
    The output fragment.
    Definition: igemm_epilogue.h:49
    GemmGlobalIteratorCd< GlobalStoreTileTraits > GlobalStoreIteratorD
    The iterator to store to shared memory.
    Definition: igemm_epilogue.h:213
    IgemmSharedStoreTransformer< typename IgemmConfig::Accumulators::Element, SharedStoreFragmentD >::Transformer SharedStoreTransformerD
    The transformer from accumulators to shared memory fragments.
    Definition: igemm_epilogue.h:233
    static bool const kInt8Output
    Do we output in int8?
    Definition: igemm_epilogue.h:283
    -
    An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
    Definition: tile_iterator.h:302
    +
    An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
    Definition: tile_iterator.h:399
    Convert< Fragment< InputScalar_, OutputFragment_::kElements >, OutputFragment_ > Transformer
    Definition: igemm_epilogue.h:180
    -
    GemmEpilogue< GemmEpilogueTraits_ > Base
    The base class.
    Definition: igemm_epilogue.h:307
    +
    GemmEpilogue< GemmEpilogueTraits_ > Base
    The base class.
    Definition: igemm_epilogue.h:306
    Defines a type for restructuring a tile.
    Base::GlobalLoadTileTraits GlobalLoadTileTraits
    The traits class for the iterator.
    Definition: igemm_epilogue.h:201
    Fragment< float, kElements_ > OutputFragment
    The output fragment.
    Definition: igemm_epilogue.h:117
    GemmEpilogueTraitsHelper< IgemmConfig_, EpilogueFunctor_, Index_ > Base
    The base class.
    Definition: igemm_epilogue.h:189
    -
    CUTLASS_DEVICE IgemmEpilogue(typename Base::Params const &params_, typename Base::SharedStorage &shared_storage_, typename Base::Index m_, typename Base::Index n_)
    Ctor.
    Definition: igemm_epilogue.h:295
    -
    Definition: gemm_shared_tile.h:335
    -
    Traits::Index Index
    The index.
    Definition: gemm_epilogue.h:93
    +
    Definition: gemm_shared_tile.h:339
    GlobalStoreIteratorD::Fragment GlobalFragmentD
    The fragment that needs to be passed to that store iterator.
    Definition: igemm_epilogue.h:215
    GemmGlobalIteratorCd< GlobalLoadTileTraits > GlobalLoadIteratorC
    The iterator to store to shared memory.
    Definition: igemm_epilogue.h:203
    -
    #define static_assert(__e, __m)
    Definition: platform.h:145
    +
    #define static_assert(__e, __m)
    Definition: platform.h:153
    IgemmConfig_ IgemmConfig
    The config.
    Definition: igemm_epilogue.h:191
    -
    CUTLASS_DEVICE IgemmEpilogue(typename Base::Params const &params_, typename Base::SharedStorage &shared_storage_, typename Base::Index m_, typename Base::Index n_)
    Ctor.
    Definition: igemm_epilogue.h:310
    A Shape implementing Layout Concept describing the dimensions of a cube.
    Definition: shape.h:64
    CUTLASS_DEVICE IgemmFloatToInt8Converter()
    Ctor.
    Definition: igemm_epilogue.h:55
    Element_ Element
    The element.
    Definition: fragment.h:108
    Fragment< float, kElements_ > InputFragment
    The input fragment.
    Definition: igemm_epilogue.h:47
    +
    Definition: gemm_epilogue_traits.h:70
    -
    Definition: gemm_global_tile.h:348
    +
    Definition: gemm_global_tile.h:396
    Definition: igemm_epilogue.h:179
    Implements efficient loading of the thread block-level tile from global memory and storing to shared ...
    -
    Fragment< FragmentElement, ShapeCount< Iterations >::kCount *kAccessSize > Fragment
    The fragment.
    Definition: tile_iterator.h:154
    Definition: convert.h:38
    IgemmFloatToInt8Converter< kElements_ > Transformer
    Definition: igemm_epilogue.h:107
    Base::Iterations Iterations
    The iterations.
    Definition: igemm_epilogue.h:196
    @@ -144,7 +143,7 @@ $(function() {
    Base::GlobalStoreTileTraits GlobalStoreTileTraits
    The traits class for the iterator.
    Definition: igemm_epilogue.h:211
    Convert< InputFragment_, Fragment< OutputScalar_, InputFragment_::kElements > > Transformer
    Definition: igemm_epilogue.h:168
    Defines Fragment, a statically-sized array for storing parts of matrices within a thread&#39;s registers...
    -
    platform::remove_const< Scalar_ >::type Scalar
    The scalar.
    Definition: gemm_shared_tile.h:266
    +
    platform::remove_const< Scalar_ >::type Scalar
    The scalar.
    Definition: gemm_shared_tile.h:272
    CUTLASS_DEVICE void transform(Fragment_ const &src, int offset, OutputFragment &dst)
    Transform a fragment.
    Definition: igemm_epilogue.h:132
    Convert< Fragment< InputScalar_, OutputFragment_::kElements >, OutputFragment_ > Transformer
    Definition: igemm_epilogue.h:102
    Defines abstractions for managing loading and storing fragments to shared memory in the efficient GEM...
    @@ -153,14 +152,15 @@ $(function() {
    IgemmInt8ToFloatConverter< kElements_ > Transformer
    Definition: igemm_epilogue.h:173
    Defines conversion operations among Fragments of different base type.
    Definition: igemm_epilogue.h:113
    -
    platform::remove_const< Scalar_ >::type Scalar
    The scalar.
    Definition: gemm_shared_tile.h:337
    +
    platform::remove_const< Scalar_ >::type Scalar
    The scalar.
    Definition: gemm_shared_tile.h:341
    +
    CUTLASS_DEVICE IgemmEpilogue(typename Base::Params const &params_, typename Base::SharedStorage &shared_storage_, Coord< 3 > const &_problem_size)
    Ctor.
    Definition: igemm_epilogue.h:309
    Implements tile iterators to partition the thread block tile into 2D subtiles and efficiently load ea...
    -
    Definition: gemm_shared_tile.h:264
    -
    An iterator implementing Tile Store Iterator Concept for storing a tile to memory.
    Definition: tile_iterator.h:620
    +
    Definition: gemm_shared_tile.h:270
    +
    An iterator implementing Tile Store Iterator Concept for storing a tile to memory.
    Definition: tile_iterator.h:836
    diff --git a/docs/igemm__global__tile_8h.html b/docs/igemm__global__tile_8h.html index d6a680168..4b5ee6d7c 100644 --- a/docs/igemm__global__tile_8h.html +++ b/docs/igemm__global__tile_8h.html @@ -82,18 +82,20 @@ $(function() {

    Implements tile iterators to partition the thread block tile into 2D subtiles and efficiently load each. Applies permute transformation to construct 'interleaved K-strided' data layout in which 4-element dot products from the same K index are arranged in consecutive locations within shared memory. More...

    -
    - + - - + + + +

    Classes

    struct  cutlass::gemm::IgemmContiguousGlobalTileTraits< kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_ >
    struct  cutlass::gemm::IgemmGlobalTileTraits< kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_ >
     
    struct  cutlass::gemm::IgemmContiguousGlobalTileTraits< kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_ >::ThreadOffset
     Computes the thread offset in (H, W) based on thread ID. More...
    struct  cutlass::gemm::IgemmGlobalTileTraits< kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_ >::ThreadOffset
     Computes the thread offset in (H, W) based on thread ID. More...
     
    struct  cutlass::gemm::IgemmGlobalIteratorAb< TileTraits_, Index_ >
     
    diff --git a/docs/igemm__global__tile_8h_source.html b/docs/igemm__global__tile_8h_source.html index df086169d..04428a68e 100644 --- a/docs/igemm__global__tile_8h_source.html +++ b/docs/igemm__global__tile_8h_source.html @@ -76,33 +76,46 @@ $(function() {
    igemm_global_tile.h
    -Go to the documentation of this file.
    1 /***************************************************************************************************
    2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
    3  *
    4  * Redistribution and use in source and binary forms, with or without modification, are permitted
    5  * provided that the following conditions are met:
    6  * * Redistributions of source code must retain the above copyright notice, this list of
    7  * conditions and the following disclaimer.
    8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
    9  * conditions and the following disclaimer in the documentation and/or other materials
    10  * provided with the distribution.
    11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
    12  * to endorse or promote products derived from this software without specific prior written
    13  * permission.
    14  *
    15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
    16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
    17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
    18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
    19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
    20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
    21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    23  *
    24  **************************************************************************************************/
    33 #pragma once
    34 
    35 #include <cutlass/coord.h>
    37 #include <cutlass/matrix_traits.h>
    38 
    39 namespace cutlass {
    40 namespace gemm {
    41 
    43 
    44 template <GemmOperand::Kind kOperand_,
    45  MatrixLayout::Kind kLayout_,
    46  typename Scalar_,
    47  typename Tile_,
    48  typename Threads_,
    49  int kAccessSize_>
    51  // Which GEMM operand?
    52  kOperand_,
    53  // The layout.
    54  kLayout_,
    55  // The scalar.
    56  Scalar_,
    57  // The tile.
    58  Tile_,
    59  // The threads.
    60  Threads_,
    61  // The number of scalars per LDG/STG.
    62  kAccessSize_> {
    66  typedef typename Base::Threads Threads;
    70  typedef Shape<Base::Tile::kH / Base::Threads::kH / 4,
    71  4,
    72  Base::Tile::kW / Base::Threads::kW,
    73  Base::Tile::kC / Base::kAccessSize>
    75 
    77  struct ThreadOffset {
    79  Coord<4> operator()() const {
    80  int thread_offset_h = threadIdx.x / Threads::kW * ThreadsDelta::kH;
    81  int thread_offset_w = threadIdx.x % Threads::kW * ThreadsDelta::kW;
    82 
    83  return make_Coord(0, thread_offset_h, thread_offset_w, 0);
    84  }
    85  };
    86 
    87  public:
    90 };
    91 
    93 
    94 } // namespace gemm
    95 } // namespace cutlass
    Computes the thread offset in (H, W) based on thread ID.
    Definition: igemm_global_tile.h:77
    -
    Definition: convert.h:33
    +Go to the documentation of this file.
    1 /***************************************************************************************************
    2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
    3  *
    4  * Redistribution and use in source and binary forms, with or without modification, are permitted
    5  * provided that the following conditions are met:
    6  * * Redistributions of source code must retain the above copyright notice, this list of
    7  * conditions and the following disclaimer.
    8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
    9  * conditions and the following disclaimer in the documentation and/or other materials
    10  * provided with the distribution.
    11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
    12  * to endorse or promote products derived from this software without specific prior written
    13  * permission.
    14  *
    15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
    16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
    17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
    18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
    19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
    20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
    21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    23  *
    24  **************************************************************************************************/
    33 #pragma once
    34 
    35 #include "cutlass/coord.h"
    37 #include "cutlass/matrix_traits.h"
    38 
    39 namespace cutlass {
    40 namespace gemm {
    41 
    43 
    44 template <GemmOperand::Kind kOperand_,
    45  MatrixLayout::Kind kLayout_,
    46  typename Scalar_,
    47  typename Tile_,
    48  typename Threads_,
    49  int kAccessSize_>
    51  // Which GEMM operand?
    52  kOperand_,
    53  // The layout.
    54  kLayout_,
    55  // The scalar.
    56  Scalar_,
    57  // The tile.
    58  Tile_,
    59  // The threads.
    60  Threads_,
    61  // The number of scalars per LDG/STG.
    62  kAccessSize_> {
    66  typedef typename Base::Threads Threads;
    70  typedef Shape<Base::VectorizedTile::kH / Base::Threads::kH / 4,
    71  4,
    72  Base::VectorizedTile::kW / Base::Threads::kW,
    73  Base::VectorizedTile::kC / Base::kAccessSize>
    75 
    77  struct ThreadOffset {
    79  Coord<4> operator()() const {
    80  int thread_offset_h = threadIdx.x / Threads::kW * ThreadsDelta::kH;
    81  int thread_offset_w = threadIdx.x % Threads::kW * ThreadsDelta::kW;
    82 
    83  return make_Coord(0, thread_offset_h, thread_offset_w, 0);
    84  }
    85  };
    86 
    87  public:
    90 };
    91 
    93 
    94 template <typename TileTraits_, typename Index_ = int>
    95 struct IgemmGlobalIteratorAb : public GemmGlobalIteratorAb<TileTraits_, Index_> {
    99  typedef typename TileTraits_::ThreadOffset ThreadOffset;
    100 
    102  CUTLASS_DEVICE IgemmGlobalIteratorAb(typename Base::Params const& _params,
    103  const Coord<3>& bounds,
    104  const Coord<3>& threadblock_offset,
    105  ThreadOffset thread_offset_func = ThreadOffset())
    106  : Base(_params, bounds, threadblock_offset, thread_offset_func), mask_(0xffffffff) {
    107  // The number of elements read in a single iteration.
    108  int const kBlock = TileTraits_::Tile::kW;
    109  // The residue.
    110  int const kResidue = (int)(bounds[1] % kBlock);
    111 
    112  // Compute the number of elements that are valid.
    113  int const left = kResidue - Base::thread_offset[2];
    114  if (left > 0 && left < 4) {
    115  mask_ = (1u << (8 * left)) - 1u;
    116  }
    117  }
    118 
    119  CUTLASS_DEVICE void load_element(
    120  typename Base::AccessType& value, int d, int h, int w, int c) const {
    121  Base::load_element(value, d, h, w, c);
    122  reinterpret_cast<uint32_t&>(value) &= mask_;
    123  }
    124 
    126  uint32_t mask_;
    127 };
    128 
    130 
    131 } // namespace gemm
    132 } // namespace cutlass
    Definition: convert.h:33
    +
    Base::Threads Threads
    The threads.
    Definition: igemm_global_tile.h:66
    +
    Computes the thread offset in (H, W) based on thread ID.
    Definition: igemm_global_tile.h:77
    Defines iterators for efficiently loading and storing to global memory.
    Definition: gemm_global_tile.h:70
    A Coord is a coordinate of arbitrary rank into a tensor or matrix.
    -
    CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
    Helper to make a 2-element coordinate.
    Definition: coord.h:241
    -
    Shape< Base::Threads::kH *4, 1, Base::Threads::kW, Base::kAccessSize > Delta
    The strides in each dimension between different loads/stores.
    Definition: igemm_global_tile.h:68
    +
    Shape< Base::VectorizedTile::kH/Base::Threads::kH/4, 4, Base::VectorizedTile::kW/Base::Threads::kW, Base::VectorizedTile::kC/Base::kAccessSize > Iterations
    The number of iterations needed to load/store the tile.
    Definition: igemm_global_tile.h:74
    +
    CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
    Helper to make a 2-element coordinate.
    Definition: coord.h:318
    +
    CUTLASS_HOST_DEVICE void load_element(typename Base::AccessType &value, int d, int h, int w, int c) const
    Loads a single fragment element from memory.
    Definition: gemm_global_tile.h:292
    +
    CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
    Definition: igemm_global_tile.h:79
    +
    Definition: gemm_global_tile.h:163
    static int const kH
    The height of the cube.
    Definition: shape.h:68
    -
    GemmGlobalTileTraits< kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_ > Base
    The base class.
    Definition: igemm_global_tile.h:64
    -
    Shape< Base::Tile::kH/Base::Threads::kH/4, 4, Base::Tile::kW/Base::Threads::kW, Base::Tile::kC/Base::kAccessSize > Iterations
    The number of iterations needed to load/store the tile.
    Definition: igemm_global_tile.h:74
    +
    An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
    Definition: tile_iterator.h:399
    +
    Definition: igemm_global_tile.h:50
    +
    CUTLASS_DEVICE void load_element(typename Base::AccessType &value, int d, int h, int w, int c) const
    Definition: igemm_global_tile.h:119
    +
    GemmGlobalIteratorAb< TileTraits_, Index_ > Base
    The base class.
    Definition: igemm_global_tile.h:97
    +
    Definition: igemm_global_tile.h:95
    #define CUTLASS_HOST_DEVICE
    Definition: cutlass.h:46
    -
    Definition: igemm_global_tile.h:50
    +
    Definition: vector.h:62
    A Shape implementing Layout Concept describing the dimensions of a cube.
    Definition: shape.h:64
    +
    TileTraits_::ThreadOffset ThreadOffset
    The functor to compute the thread offset.
    Definition: igemm_global_tile.h:99
    +
    uint32_t mask_
    The mask to clean up the values.
    Definition: igemm_global_tile.h:126
    +
    ReshapeThreads< VectorizedTile, Threads_ >::Threads Threads
    The threads shape.
    Definition: gemm_global_tile.h:88
    +
    CUTLASS_DEVICE IgemmGlobalIteratorAb(typename Base::Params const &_params, const Coord< 3 > &bounds, const Coord< 3 > &threadblock_offset, ThreadOffset thread_offset_func=ThreadOffset())
    Constructor.
    Definition: igemm_global_tile.h:102
    +
    Shape< 1, 4, Base::VectorizedTile::kC > ThreadsDelta
    The threads strides.
    Definition: igemm_global_tile.h:89
    +
    TileTraits_::ThreadOffset ThreadOffset
    The thread offset.
    Definition: gemm_global_tile.h:192
    static int const kW
    The width of the cube.
    Definition: shape.h:70
    -
    Kind
    Definition: matrix_traits.h:36
    +
    Parameters.
    Definition: tile_iterator.h:491
    +
    Kind
    Enumeration defining fundamental contiguous layouts.
    Definition: matrix_traits.h:159
    static int const kAccessSize
    The number of scalars per LDG/STG.
    Definition: gemm_global_tile.h:80
    -
    Kind
    Definition: matrix_traits.h:43
    -
    ReshapeThreads< Tile, Threads_ >::Threads Threads
    The threads shape.
    Definition: gemm_global_tile.h:87
    +
    Kind
    Definition: matrix_traits.h:357
    +
    Shape< Base::Threads::kH *4, 1, Base::Threads::kW, Base::kAccessSize > Delta
    The strides in each dimension between different loads/stores.
    Definition: igemm_global_tile.h:68
    Defines properties of matrices used to denote layout and operands to GEMM kernels.
    -
    Shape< 1, 4, Base::Tile::kC > ThreadsDelta
    The threads strides.
    Definition: igemm_global_tile.h:89
    -
    CUTLASS_HOST_DEVICE Coord< 4 > operator()() const
    Definition: igemm_global_tile.h:79
    -
    Base::Threads Threads
    The threads.
    Definition: igemm_global_tile.h:66
    +
    Coord< 4 > thread_offset
    Offset of an individual lane from the start of the tile.
    Definition: gemm_global_tile.h:237
    +
    GemmGlobalTileTraits< kOperand_, kLayout_, Scalar_, Tile_, Threads_, kAccessSize_ > Base
    The base class.
    Definition: igemm_global_tile.h:64
    diff --git a/docs/igemm__multiply__add_8h.html b/docs/igemm__multiply__add_8h.html index 266cb5f16..d67e57b8d 100644 --- a/docs/igemm__multiply__add_8h.html +++ b/docs/igemm__multiply__add_8h.html @@ -82,15 +82,15 @@ $(function() {

    Implements matrix multiply accumulate operation of 8-bit integer data using DP4A instruction. More...

    -

    @@ -108,7 +110,7 @@ Namespaces

    - - + +

    Classes

    struct  cutlass::gemm::ThreadMultiplyAdd< AccumulatorsPerThread_, ThreadsPerWarp_, int8_t, int8_t, int >
     Template performing matrix multiply-add operation within a thread. More...
    struct  cutlass::gemm::ThreadMultiplyAdd< ThreadGemmShape_, ThreadsPerWarp_, int8_t, int8_t, int >
     Template performing matrix multiply-add operation within a thread. More...
     
    diff --git a/docs/igemm__multiply__add_8h_source.html b/docs/igemm__multiply__add_8h_source.html index 414c2ce17..b67129ef4 100644 --- a/docs/igemm__multiply__add_8h_source.html +++ b/docs/igemm__multiply__add_8h_source.html @@ -76,29 +76,30 @@ $(function() {
    igemm_multiply_add.h
    -Go to the documentation of this file.
    1 /***************************************************************************************************
    2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
    3  *
    4  * Redistribution and use in source and binary forms, with or without modification, are permitted
    5  * provided that the following conditions are met:
    6  * * Redistributions of source code must retain the above copyright notice, this list of
    7  * conditions and the following disclaimer.
    8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
    9  * conditions and the following disclaimer in the documentation and/or other materials
    10  * provided with the distribution.
    11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
    12  * to endorse or promote products derived from this software without specific prior written
    13  * permission.
    14  *
    15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
    16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
    17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
    18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
    19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
    20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
    21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    23  *
    24  **************************************************************************************************/
    29 #pragma once
    30 
    31 #include <cutlass/fragment.h>
    32 
    34 
    35 namespace cutlass {
    36 namespace gemm {
    37 
    39 
    41 template <typename AccumulatorsPerThread_, typename ThreadsPerWarp_>
    42 struct ThreadMultiplyAdd<AccumulatorsPerThread_, ThreadsPerWarp_, int8_t, int8_t, int> {
    46  typedef AccumulatorsPerThread_ AccumulatorsPerThread;
    48  typedef ThreadsPerWarp_ ThreadsPerWarp;
    52  typedef int8_t ScalarA;
    56  typedef int8_t ScalarB;
    60  typedef int ScalarC;
    63 
    65  CUTLASS_DEVICE ThreadMultiplyAdd() {}
    66 
    68  CUTLASS_DEVICE void multiply_add(FragmentA const& a,
    69  FragmentB const& b,
    70  Accumulators const& c,
    71  Accumulators& d) {
    72  // The inputs.
    73  int const* a_int = reinterpret_cast<int const*>(&a[0]);
    74  int const* b_int = reinterpret_cast<int const*>(&b[0]);
    75 
    76  for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
    77  for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {
    78  asm volatile("dp4a.s32.s32 %0, %1, %2, %3;"
    79  : "=r"(d[j * AccumulatorsPerThread::kW + i])
    80  : "r"(a_int[i]), "r"(b_int[j]), "r"(c[j * AccumulatorsPerThread::kW + i]));
    81  }
    82  }
    83  }
    84 };
    85 
    87 
    88 } // namespace gemm
    89 } // namespace cutlass
    -
    Definition: convert.h:33
    +Go to the documentation of this file.
    1 /***************************************************************************************************
    2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
    3  *
    4  * Redistribution and use in source and binary forms, with or without modification, are permitted
    5  * provided that the following conditions are met:
    6  * * Redistributions of source code must retain the above copyright notice, this list of
    7  * conditions and the following disclaimer.
    8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
    9  * conditions and the following disclaimer in the documentation and/or other materials
    10  * provided with the distribution.
    11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
    12  * to endorse or promote products derived from this software without specific prior written
    13  * permission.
    14  *
    15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
    16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
    17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
    18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
    19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
    20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
    21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    23  *
    24  **************************************************************************************************/
    29 #pragma once
    30 
    31 #include "cutlass/fragment.h"
    32 
    34 
    35 namespace cutlass {
    36 namespace gemm {
    37 
    39 
    41 template <typename ThreadGemmShape_, typename ThreadsPerWarp_>
    42 struct ThreadMultiplyAdd<ThreadGemmShape_, ThreadsPerWarp_, int8_t, int8_t, int> {
    46  typedef ThreadGemmShape_ ThreadGemmShape;
    50  typedef ThreadsPerWarp_ ThreadsPerWarp;
    54  typedef int8_t ScalarA;
    58  typedef int8_t ScalarB;
    62  typedef int ScalarC;
    65 
    67  CUTLASS_DEVICE ThreadMultiplyAdd() {}
    68 
    70  CUTLASS_DEVICE void multiply_add(FragmentA const& a,
    71  FragmentB const& b,
    72  Accumulators const& c,
    73  Accumulators& d) {
    74  // The inputs.
    75  int const* a_int = reinterpret_cast<int const*>(&a[0]);
    76  int const* b_int = reinterpret_cast<int const*>(&b[0]);
    77 
    78  for (int j = 0; j < AccumulatorsPerThread::kH; ++j) {
    79  for (int i = 0; i < AccumulatorsPerThread::kW; ++i) {
    80  asm volatile("dp4a.s32.s32 %0, %1, %2, %3;"
    81  : "=r"(d[j * AccumulatorsPerThread::kW + i])
    82  : "r"(a_int[i]), "r"(b_int[j]), "r"(c[j * AccumulatorsPerThread::kW + i]));
    83  }
    84  }
    85  }
    86 };
    87 
    89 
    90 } // namespace gemm
    91 } // namespace cutlass
    Definition: convert.h:33
    +
    Fragment< ScalarA, AccumulatorsPerThread::kW *4 > FragmentA
    The fragment for A.
    Definition: igemm_multiply_add.h:56
    Shape< A_::kD *B_::kD, A_::kH *B_::kH, A_::kW *B_::kW, A_::kC *B_::kC > Shape
    Definition: shape.h:119
    A template defining Fragment Concept.
    Definition: fragment.h:99
    Template implementing matrix multiply-add operations on fragments.
    -
    Fragment< ScalarC, AccumulatorsPerThread::kH *AccumulatorsPerThread::kW > Accumulators
    The accumulators.
    Definition: igemm_multiply_add.h:62
    -
    ShapeMul< AccumulatorsPerThread, ThreadsPerWarp >::Shape AccumulatorsPerWarp
    The number of accumulators per warp.
    Definition: igemm_multiply_add.h:50
    -
    Fragment< ScalarB, AccumulatorsPerThread::kH *4 > FragmentB
    The fragment for B.
    Definition: igemm_multiply_add.h:58
    - -
    Shape< 4, 1, 1 > InstructionShape
    The shape of the instruction.
    Definition: igemm_multiply_add.h:44
    -
    ThreadsPerWarp_ ThreadsPerWarp
    The number of threads per warp.
    Definition: igemm_multiply_add.h:48
    -
    AccumulatorsPerThread_ AccumulatorsPerThread
    The number of accumulators per thread.
    Definition: igemm_multiply_add.h:46
    +
    CUTLASS_DEVICE ThreadMultiplyAdd()
    Ctor.
    Definition: igemm_multiply_add.h:67
    +
    CUTLASS_DEVICE void multiply_add(FragmentA const &a, FragmentB const &b, Accumulators const &c, Accumulators &d)
    Multiply : d = a*b + c.
    Definition: igemm_multiply_add.h:70
    +
    int ScalarC
    The type for C and D.
    Definition: igemm_multiply_add.h:62
    +
    Shape< 4, 1, 1 > InstructionShape
    The shape of the instruction.
    Definition: igemm_multiply_add.h:44
    +
    ThreadsPerWarp_ ThreadsPerWarp
    The number of threads per warp.
    Definition: igemm_multiply_add.h:50
    +
    ThreadGemmShape_ ThreadGemmShape
    Shape of the thread-level GEMM (K-by-N-by-M)
    Definition: igemm_multiply_add.h:46
    +
    Fragment< ScalarC, AccumulatorsPerThread::kH *AccumulatorsPerThread::kW > Accumulators
    The accumulators.
    Definition: igemm_multiply_add.h:64
    A Shape implementing Layout Concept describing the dimensions of a cube.
    Definition: shape.h:64
    -
    Template performing matrix multiply-add operation within a thread.
    Definition: thread_multiply_add.h:43
    -
    Fragment< ScalarA, AccumulatorsPerThread::kW *4 > FragmentA
    The fragment for A.
    Definition: igemm_multiply_add.h:54
    - - -
    CUTLASS_DEVICE void multiply_add(FragmentA const &a, FragmentB const &b, Accumulators const &c, Accumulators &d)
    Multiply : d = a*b + c.
    Definition: igemm_multiply_add.h:68
    +
    ShapeMul< ThreadGemmShape, ThreadsPerWarp >::Shape AccumulatorsPerWarp
    The number of accumulators per warp.
    Definition: igemm_multiply_add.h:52
    +
    Template performing matrix multiply-add operation within a thread.
    Definition: thread_multiply_add.h:44
    +
    ThreadGemmShape AccumulatorsPerThread
    Aliased for compatibility. Will be removed in CUTLASS v2.0.
    Definition: igemm_multiply_add.h:48
    +
    Fragment< ScalarB, AccumulatorsPerThread::kH *4 > FragmentB
    The fragment for B.
    Definition: igemm_multiply_add.h:60
    +
    Defines Fragment, a statically-sized array for storing parts of matrices within a thread&#39;s registers...
    +
    diff --git a/docs/igemm__swizzle_8h.html b/docs/igemm__swizzle_8h.html index a631d215c..c87855219 100644 --- a/docs/igemm__swizzle_8h.html +++ b/docs/igemm__swizzle_8h.html @@ -82,7 +82,7 @@ $(function() {

    Transposes a fragment of data containing packed 8-bit integer elements. More...

    -

    @@ -103,7 +103,7 @@ Namespaces

    @@ -101,7 +101,7 @@ Namespaces diff --git a/docs/igemm__swizzle_8h_source.html b/docs/igemm__swizzle_8h_source.html index 939908301..015b5f9af 100644 --- a/docs/igemm__swizzle_8h_source.html +++ b/docs/igemm__swizzle_8h_source.html @@ -76,14 +76,14 @@ $(function() {
    igemm_swizzle.h
    -Go to the documentation of this file.
    1 /***************************************************************************************************
    2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
    3  *
    4  * Redistribution and use in source and binary forms, with or without modification, are permitted
    5  * provided that the following conditions are met:
    6  * * Redistributions of source code must retain the above copyright notice, this list of
    7  * conditions and the following disclaimer.
    8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
    9  * conditions and the following disclaimer in the documentation and/or other materials
    10  * provided with the distribution.
    11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
    12  * to endorse or promote products derived from this software without specific prior written
    13  * permission.
    14  *
    15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
    16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
    17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
    18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
    19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
    20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
    21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    23  *
    24  **************************************************************************************************/
    28 #pragma once
    29 
    30 #include <cutlass/fragment.h>
    31 
    32 namespace cutlass {
    33 namespace gemm {
    34 
    36 
    37 template <typename GlobalIterator_>
    38 struct IgemmSwizzle {
    40  typedef GlobalIterator_ GlobalIterator;
    42  typedef typename GlobalIterator::Fragment Fragment;
    44  typedef typename GlobalIterator::FragmentShape FragmentShape;
    45 
    50 
    53 
    55  static_assert(FragmentShape::kH % 4 == 0 && ShapeCount<FragmentShape>::kWc % 4 == 0,
    56  "Not multiple of 4");
    57 
    59  CUTLASS_DEVICE IgemmSwizzle() {}
    60 
    62  CUTLASS_DEVICE void transform(Fragment const& src, Fragment& dst) {
    63  // Expose src/dst as int arrays.
    64  int const* src_int = reinterpret_cast<int const*>(&src[0]);
    65  int* dst_int = reinterpret_cast<int*>(&dst[0]);
    66 
    67  // Transpose the data.
    68  for (int d = 0; d < FragmentShape::kD; ++d) {
    69  for (int h = 0; h < FragmentShape::kH / 4; ++h) {
    70  for (int w = 0; w < ShapeCount<FragmentShape>::kWc / 4; ++w) {
    71  int const i0 = d * (ShapeCount<FragmentShape>::kHwc / 4) +
    72  (4 * h + 0) * (ShapeCount<FragmentShape>::kWc / 4) + w;
    73  int const i1 = d * (ShapeCount<FragmentShape>::kHwc / 4) +
    74  (4 * h + 1) * (ShapeCount<FragmentShape>::kWc / 4) + w;
    75  int const i2 = d * (ShapeCount<FragmentShape>::kHwc / 4) +
    76  (4 * h + 2) * (ShapeCount<FragmentShape>::kWc / 4) + w;
    77  int const i3 = d * (ShapeCount<FragmentShape>::kHwc / 4) +
    78  (4 * h + 3) * (ShapeCount<FragmentShape>::kWc / 4) + w;
    79 
    80  int a0 = src_int[i0];
    81  int a1 = src_int[i1];
    82  int a2 = src_int[i2];
    83  int a3 = src_int[i3];
    84 
    85  int b0, b1, b2, b3, c0;
    86  asm volatile("prmt.b32 %0, %1, %2, 0x0040;" : "=r"(b0) : "r"(a0), "r"(a1));
    87  asm volatile("prmt.b32 %0, %1, %2, 0x0040;" : "=r"(c0) : "r"(a2), "r"(a3));
    88  asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b0) : "r"(b0), "r"(c0));
    89 
    90  asm volatile("prmt.b32 %0, %1, %2, 0x0051;" : "=r"(b1) : "r"(a0), "r"(a1));
    91  asm volatile("prmt.b32 %0, %1, %2, 0x0051;" : "=r"(c0) : "r"(a2), "r"(a3));
    92  asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b1) : "r"(b1), "r"(c0));
    93 
    94  asm volatile("prmt.b32 %0, %1, %2, 0x0062;" : "=r"(b2) : "r"(a0), "r"(a1));
    95  asm volatile("prmt.b32 %0, %1, %2, 0x0062;" : "=r"(c0) : "r"(a2), "r"(a3));
    96  asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b2) : "r"(b2), "r"(c0));
    97 
    98  asm volatile("prmt.b32 %0, %1, %2, 0x0073;" : "=r"(b3) : "r"(a0), "r"(a1));
    99  asm volatile("prmt.b32 %0, %1, %2, 0x0073;" : "=r"(c0) : "r"(a2), "r"(a3));
    100  asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b3) : "r"(b3), "r"(c0));
    101 
    102  dst_int[i0] = b0;
    103  dst_int[i1] = b1;
    104  dst_int[i2] = b2;
    105  dst_int[i3] = b3;
    106  }
    107  }
    108  }
    109  }
    110 };
    111 
    113 
    114 } // namespace gemm
    115 } // namespace cutlass
    Definition: convert.h:33
    -
    std::is_same (false specialization)
    Definition: platform.h:412
    +Go to the documentation of this file.
    1 /***************************************************************************************************
    2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
    3  *
    4  * Redistribution and use in source and binary forms, with or without modification, are permitted
    5  * provided that the following conditions are met:
    6  * * Redistributions of source code must retain the above copyright notice, this list of
    7  * conditions and the following disclaimer.
    8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
    9  * conditions and the following disclaimer in the documentation and/or other materials
    10  * provided with the distribution.
    11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
    12  * to endorse or promote products derived from this software without specific prior written
    13  * permission.
    14  *
    15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
    16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
    17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
    18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
    19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
    20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
    21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    23  *
    24  **************************************************************************************************/
    28 #pragma once
    29 
    30 #include "cutlass/fragment.h"
    31 
    32 namespace cutlass {
    33 namespace gemm {
    34 
    36 
    37 template <typename GlobalIterator_>
    38 struct IgemmSwizzle {
    40  typedef GlobalIterator_ GlobalIterator;
    42  typedef typename GlobalIterator::Fragment Fragment;
    44  typedef typename GlobalIterator::FragmentShape FragmentShape;
    45 
    50 
    53 
    55  static_assert(FragmentShape::kH % 4 == 0 && ShapeCount<FragmentShape>::kWc % 4 == 0,
    56  "Not multiple of 4");
    57 
    59  CUTLASS_DEVICE IgemmSwizzle() {}
    60 
    62  CUTLASS_DEVICE void transform(Fragment const& src, Fragment& dst) {
    63  // Expose src/dst as int arrays.
    64  int const* src_int = reinterpret_cast<int const*>(&src[0]);
    65  int* dst_int = reinterpret_cast<int*>(&dst[0]);
    66 
    67  // Transpose the data.
    68  for (int d = 0; d < FragmentShape::kD; ++d) {
    69  for (int h = 0; h < FragmentShape::kH / 4; ++h) {
    70  for (int w = 0; w < ShapeCount<FragmentShape>::kWc / 4; ++w) {
    71  int const i0 = d * (ShapeCount<FragmentShape>::kHwc / 4) +
    72  (4 * h + 0) * (ShapeCount<FragmentShape>::kWc / 4) + w;
    73  int const i1 = d * (ShapeCount<FragmentShape>::kHwc / 4) +
    74  (4 * h + 1) * (ShapeCount<FragmentShape>::kWc / 4) + w;
    75  int const i2 = d * (ShapeCount<FragmentShape>::kHwc / 4) +
    76  (4 * h + 2) * (ShapeCount<FragmentShape>::kWc / 4) + w;
    77  int const i3 = d * (ShapeCount<FragmentShape>::kHwc / 4) +
    78  (4 * h + 3) * (ShapeCount<FragmentShape>::kWc / 4) + w;
    79 
    80  int a0 = src_int[i0];
    81  int a1 = src_int[i1];
    82  int a2 = src_int[i2];
    83  int a3 = src_int[i3];
    84 
    85  // // DEBUG.
    86  // if (threadIdx.x == 0) {
    87  // printf("a=0x%08x 0x%08x 0x%08x 0x%08x\n", a0, a1, a2, a3);
    88  // }
    89 
    90  int b0, b1, b2, b3, c0;
    91  asm volatile("prmt.b32 %0, %1, %2, 0x0040;" : "=r"(b0) : "r"(a0), "r"(a1));
    92  asm volatile("prmt.b32 %0, %1, %2, 0x0040;" : "=r"(c0) : "r"(a2), "r"(a3));
    93  asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b0) : "r"(b0), "r"(c0));
    94 
    95  asm volatile("prmt.b32 %0, %1, %2, 0x0051;" : "=r"(b1) : "r"(a0), "r"(a1));
    96  asm volatile("prmt.b32 %0, %1, %2, 0x0051;" : "=r"(c0) : "r"(a2), "r"(a3));
    97  asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b1) : "r"(b1), "r"(c0));
    98 
    99  asm volatile("prmt.b32 %0, %1, %2, 0x0062;" : "=r"(b2) : "r"(a0), "r"(a1));
    100  asm volatile("prmt.b32 %0, %1, %2, 0x0062;" : "=r"(c0) : "r"(a2), "r"(a3));
    101  asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b2) : "r"(b2), "r"(c0));
    102 
    103  asm volatile("prmt.b32 %0, %1, %2, 0x0073;" : "=r"(b3) : "r"(a0), "r"(a1));
    104  asm volatile("prmt.b32 %0, %1, %2, 0x0073;" : "=r"(c0) : "r"(a2), "r"(a3));
    105  asm volatile("prmt.b32 %0, %1, %2, 0x5410;" : "=r"(b3) : "r"(b3), "r"(c0));
    106 
    107  // // DEBUG.
    108  // if (threadIdx.x == 0) {
    109  // printf("b=0x%08x 0x%08x 0x%08x 0x%08x\n", b0, b1, b2, b3);
    110  // }
    111 
    112  dst_int[i0] = b0;
    113  dst_int[i1] = b1;
    114  dst_int[i2] = b2;
    115  dst_int[i3] = b3;
    116  }
    117  }
    118  }
    119  }
    120 };
    121 
    123 
    124 } // namespace gemm
    125 } // namespace cutlass
    Definition: convert.h:33
    +
    std::is_same (false specialization)
    Definition: platform.h:420
    GlobalIterator::FragmentShape FragmentShape
    The shape of the source fragment.
    Definition: igemm_swizzle.h:44
    Definition: igemm_swizzle.h:38
    GlobalIterator_ GlobalIterator
    The global iterator.
    Definition: igemm_swizzle.h:40
    CUTLASS_DEVICE void transform(Fragment const &src, Fragment &dst)
    Transform a fragment.
    Definition: igemm_swizzle.h:62
    Fragment OutputFragment
    The destination fragment.
    Definition: igemm_swizzle.h:49
    -
    #define static_assert(__e, __m)
    Definition: platform.h:145
    +
    #define static_assert(__e, __m)
    Definition: platform.h:153
    Fragment InputFragment
    The source fragment.
    Definition: igemm_swizzle.h:47
    GlobalIterator::Fragment Fragment
    The source fragment.
    Definition: igemm_swizzle.h:42
    CUTLASS_DEVICE IgemmSwizzle()
    The src/dst must be int8 fragments.
    Definition: igemm_swizzle.h:59
    @@ -92,7 +92,7 @@ $(function() {
    diff --git a/docs/igemm__traits_8h.html b/docs/igemm__traits_8h.html index 32d14d876..897687ee2 100644 --- a/docs/igemm__traits_8h.html +++ b/docs/igemm__traits_8h.html @@ -82,34 +82,38 @@ $(function() {

    Defies structural properties of mixed-precision integer GEMM. Multiplicands are assumed to be packed 8bit integers, accumulators are assumed to be 32b signed integers, and output formats vary. More...

    -
    - + - + - + - + - + - + + + + + @@ -123,13 +127,13 @@ Classes - + - +

    Classes

    struct  cutlass::gemm::IgemmConfig< OutputTile_, ScalarD_, AccumulatorsPerThread_ >
    struct  cutlass::gemm::IgemmConfig< OutputTile_, ScalarD_, ThreadGemmShape_ >
     
    struct  cutlass::gemm::IgemmConfig< OutputTile_, int8_t, AccumulatorsPerThread_ >
    struct  cutlass::gemm::IgemmConfig< OutputTile_, int8_t, ThreadGemmShape_ >
     
    struct  cutlass::gemm::IgemmTileTraitsHelperA< kLayout_, GemmConfig_ >
    struct  cutlass::gemm::IgemmTileTraitsHelperA< kLayout_, GemmConfig_, Index_ >
     
    struct  cutlass::gemm::IgemmTileTraitsHelperA< MatrixLayout::kColumnMajor, GemmConfig_ >
    struct  cutlass::gemm::IgemmTileTraitsHelperA< MatrixLayout::kColumnMajor, GemmConfig_, Index_ >
     
    struct  cutlass::gemm::IgemmTileTraitsHelperB< kLayout_, GemmConfig_ >
    struct  cutlass::gemm::IgemmTileTraitsHelperA< MatrixLayout::kRowMajor, GemmConfig_, Index_ >
     
    struct  cutlass::gemm::IgemmTileTraitsHelperB< MatrixLayout::kRowMajor, GemmConfig_ >
    struct  cutlass::gemm::IgemmTileTraitsHelperB< kLayout_, GemmConfig_, Index_ >
     
    struct  cutlass::gemm::IgemmTileTraitsHelperB< MatrixLayout::kColumnMajor, GemmConfig_, Index_ >
     
    struct  cutlass::gemm::IgemmTileTraitsHelperB< MatrixLayout::kRowMajor, GemmConfig_, Index_ >
     
    struct  cutlass::gemm::IgemmTransformerA< kLayout_, Iterator_ >
     
     
    struct  cutlass::gemm::IgemmTransformerB< MatrixLayout::kRowMajor, Iterator_ >
     
    struct  cutlass::gemm::IgemmTraitsHelper< kLayoutA_, kLayoutB_, OutputTile_, ScalarD_, EpilogueFunctor_, AccumulatorsPerThread_, Index_ >
    struct  cutlass::gemm::IgemmTraitsHelper< kLayoutA_, kLayoutB_, OutputTile_, ScalarD_, EpilogueFunctor_, ThreadGemmShape_, Index_ >
     
    struct  cutlass::gemm::IgemmEpilogueScalar< ScalarD_ >
     
    struct  cutlass::gemm::IgemmEpilogueScalar< int >
     
    struct  cutlass::gemm::IgemmTraits< kLayoutA_, kLayoutB_, OutputTile_, ScalarD_, EpilogueFunctor_, AccumulatorsPerThread_, Index_, Helper_ >
    struct  cutlass::gemm::IgemmTraits< kLayoutA_, kLayoutB_, OutputTile_, ScalarD_, EpilogueFunctor_, ThreadGemmShape_, Index_, Helper_ >
     
    diff --git a/docs/igemm__traits_8h_source.html b/docs/igemm__traits_8h_source.html index ecdd4f1df..e1fa87e40 100644 --- a/docs/igemm__traits_8h_source.html +++ b/docs/igemm__traits_8h_source.html @@ -76,89 +76,108 @@ $(function() {
    igemm_traits.h
    -Go to the documentation of this file.
    1 /***************************************************************************************************
    2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
    3  *
    4  * Redistribution and use in source and binary forms, with or without modification, are permitted
    5  * provided that the following conditions are met:
    6  * * Redistributions of source code must retain the above copyright notice, this list of
    7  * conditions and the following disclaimer.
    8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
    9  * conditions and the following disclaimer in the documentation and/or other materials
    10  * provided with the distribution.
    11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
    12  * to endorse or promote products derived from this software without specific prior written
    13  * permission.
    14  *
    15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
    16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
    17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
    18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
    19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
    20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
    21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    23  *
    24  **************************************************************************************************/
    30 #pragma once
    31 
    32 #include <cutlass/convert.h>
    33 #include <cutlass/gemm/gemm.h>
    43 #include <cutlass/reshape_tile.h>
    44 
    45 namespace cutlass {
    46 namespace gemm {
    47 
    49 
    50 template <
    52  typename OutputTile_,
    54  typename ScalarD_,
    56  typename AccumulatorsPerThread_>
    58  : public GemmConfig<
    60  int8_t,
    62  int8_t,
    64  ScalarD_,
    66  ScalarD_,
    68  OutputTile_,
    70  ThreadMultiplyAdd<AccumulatorsPerThread_, Shape<1, 4, 8>, int8_t, int8_t, int>,
    72  4,
    74  4,
    76  16,
    78  4,
    80  4,
    82  16,
    84  1,
    86  4,
    88  1,
    90  2> {};
    91 
    93 
    94 template <typename OutputTile_, typename AccumulatorsPerThread_>
    95 struct IgemmConfig<OutputTile_, int8_t, AccumulatorsPerThread_>
    96  : public GemmConfig<
    98  int8_t,
    100  int8_t,
    102  int8_t,
    104  int8_t,
    106  OutputTile_,
    108  ThreadMultiplyAdd<AccumulatorsPerThread_, Shape<1, 4, 8>, int8_t, int8_t, int>,
    110  4,
    112  4,
    114  16,
    116  4,
    118  4,
    120  16,
    122  4,
    124  4,
    126  4,
    128  2> {};
    129 
    131 
    132 template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
    133 struct IgemmTileTraitsHelperA : public GemmTileTraitsHelperA<kLayout_, GemmConfig_> {};
    134 
    136 
    137 template <typename GemmConfig_>
    138 struct IgemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_>
    139  : public GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> {
    142 
    144  static int const kScalarsPerStsA = 16;
    145 
    149  // The layout.
    151  // The pointer is float const.
    152  int8_t const,
    153  // The tile has size KxM in GEMM's terminology.
    155  // The threads are distributed as warps x 32 (the traits may reorganize).
    157  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
    158  4>
    160 
    163  // The pointer is float.
    164  int8_t,
    165  // The tile has size KxM in GEMM's terminology.
    166  Shape<GemmConfig_::kStages, GemmConfig_::OutputTile::kD / 4, GemmConfig_::OutputTile::kW * 4>,
    167  // The threads are distributed as warps x 32 (the traits may reorganize).
    168  typename GlobalTileTraits::Threads,
    169  // The number of scalars per STS (STS.32 or STS.128, etc).
    170  kScalarsPerStsA>
    172 };
    173 
    175 
    176 template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_>
    177 struct IgemmTileTraitsHelperB : public GemmTileTraitsHelperB<kLayout_, GemmConfig_> {};
    178 
    180 
    181 template <typename GemmConfig_>
    182 struct IgemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_>
    183  : public GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> {
    186 
    188  static int const kScalarsPerStsB = 16;
    189 
    193  // The layout.
    195  // The pointer is float const.
    196  int8_t const,
    197  // The tile has size KxM in GEMM's terminology.
    199  // The threads are distributed as warps x 32 (the traits may reorganize).
    201  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
    202  4>
    204 
    207  // The pointer is float.
    208  int8_t,
    209  // The tile has size KxM in GEMM's terminology.
    210  Shape<GemmConfig_::kStages, GemmConfig_::OutputTile::kD / 4, GemmConfig_::OutputTile::kH * 4>,
    211  // The threads are distributed as warps x 32 (the traits may reorganize).
    212  typename GlobalTileTraits::Threads,
    213  // The number of scalars per STS (STS.32 or STS.128, etc).
    214  kScalarsPerStsB>
    216 };
    217 
    219 
    220 template <enum MatrixLayout::Kind kLayout_, typename Iterator_>
    222 
    223 template <typename Iterator_>
    224 struct IgemmTransformerA<MatrixLayout::kRowMajor, Iterator_> {
    226 };
    227 
    228 template <typename Iterator_>
    229 struct IgemmTransformerA<MatrixLayout::kColumnMajor, Iterator_> {
    231 };
    232 
    234 
    235 template <enum MatrixLayout::Kind kLayout_, typename Iterator_>
    237 
    238 template <typename Iterator_>
    239 struct IgemmTransformerB<MatrixLayout::kColumnMajor, Iterator_> {
    241 };
    242 
    243 template <typename Iterator_>
    244 struct IgemmTransformerB<MatrixLayout::kRowMajor, Iterator_> {
    246 };
    247 
    249 
    250 template <
    252  MatrixLayout::Kind kLayoutA_,
    254  MatrixLayout::Kind kLayoutB_,
    256  typename OutputTile_,
    258  typename ScalarD_,
    260  typename EpilogueFunctor_,
    262  typename AccumulatorsPerThread_ = Shape<32, 8, 8>,
    264  typename Index_ = int>
    272 
    277  typedef typename IgemmTransformerA<GemmTileTraitsHelperA::kLayout,
    280  typedef TileStoreIterator<typename GemmTileTraitsHelperA::SharedStoreTileTraits,
    281  typename GemmTileTraitsHelperA::SharedStoreTileTraits::Scalar,
    288 
    292  // The default transformer for B.
    293  typedef typename IgemmTransformerB<GemmTileTraitsHelperB::kLayout,
    296  typedef TileStoreIterator<typename GemmTileTraitsHelperB::SharedStoreTileTraits,
    297  typename GemmTileTraitsHelperB::SharedStoreTileTraits::Scalar,
    304 
    306  typedef TileLoadIterator<typename GemmTileTraitsHelperA::SharedLoadTileTraits,
    307  typename GemmTileTraitsHelperA::SharedLoadTileTraits::Scalar,
    315  typedef TileLoadIterator<typename GemmTileTraitsHelperB::SharedLoadTileTraits,
    316  typename GemmTileTraitsHelperB::SharedLoadTileTraits::Scalar,
    323 
    328 
    331 };
    332 
    334 
    335 template <typename ScalarD_>
    337  typedef float Scalar;
    338 };
    339 
    340 template <>
    341 struct IgemmEpilogueScalar<int> {
    342  typedef int Scalar;
    343 };
    344 
    346 
    347 template <
    349  MatrixLayout::Kind kLayoutA_,
    351  MatrixLayout::Kind kLayoutB_,
    353  typename OutputTile_ = Shape<32, 128, 128>,
    355  typename ScalarD_ = int,
    359  typename AccumulatorsPerThread_ = Shape<32, 8, 8>,
    361  typename Index_ = int,
    363  typename Helper_ = IgemmTraitsHelper<kLayoutA_,
    364  kLayoutB_,
    365  OutputTile_,
    366  ScalarD_,
    367  EpilogueFunctor_,
    368  AccumulatorsPerThread_,
    369  Index_> >
    370 struct IgemmTraits : public GemmTraits<
    371  // The config.
    372  typename Helper_::GemmConfig,
    373  // The stream to load A from global memory to shared memory.
    374  typename Helper_::GlobalLoadStreamA,
    375  // The stream to load B from global memory to shared memory.
    376  typename Helper_::GlobalLoadStreamB,
    377  // The stream to load A from shared memory.
    378  typename Helper_::SharedLoadStreamA,
    379  // The stream to load B from shared memory.
    380  typename Helper_::SharedLoadStreamB,
    381  // The epilogue.
    382  typename Helper_::Epilogue,
    383  // The block swizzle to reorganize the grid.
    384  IdentityBlockSwizzle,
    385  // The index.
    386  Index_,
    387  // The tool used to clear accumulators.
    388  typename Helper_::ClearAccumulators> {};
    389 
    391 
    392 } // namespace gemm
    393 } // namespace cutlass
    Definition: load_store.h:42
    -
    TileLoadIterator< typename GemmTileTraitsHelperB::SharedLoadTileTraits, typename GemmTileTraitsHelperB::SharedLoadTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedLoadIteratorB
    The iterator to load B from shared memory.
    Definition: igemm_traits.h:319
    +Go to the documentation of this file.
    1 /***************************************************************************************************
    2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
    3  *
    4  * Redistribution and use in source and binary forms, with or without modification, are permitted
    5  * provided that the following conditions are met:
    6  * * Redistributions of source code must retain the above copyright notice, this list of
    7  * conditions and the following disclaimer.
    8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
    9  * conditions and the following disclaimer in the documentation and/or other materials
    10  * provided with the distribution.
    11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
    12  * to endorse or promote products derived from this software without specific prior written
    13  * permission.
    14  *
    15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
    16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
    17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
    18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
    19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
    20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
    21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    23  *
    24  **************************************************************************************************/
    30 #pragma once
    31 
    32 #include "cutlass/convert.h"
    33 #include "cutlass/gemm/gemm.h"
    43 #include "cutlass/reshape_tile.h"
    44 
    45 namespace cutlass {
    46 namespace gemm {
    47 
    49 
    50 template <
    52  typename OutputTile_,
    54  typename ScalarD_,
    56  typename ThreadGemmShape_>
    57 struct IgemmConfig : public GemmConfig<
    59  int8_t,
    61  int8_t,
    63  ScalarD_,
    65  ScalarD_,
    67  OutputTile_,
    69  ThreadMultiplyAdd<ThreadGemmShape_, Shape<1, 4, 8>, int8_t, int8_t, int>,
    71  4,
    73  4,
    75  16,
    77  4,
    79  4,
    81  16,
    83  1,
    85  4,
    87  1,
    89  2,
    91  false,
    93  false,
    95  false> {};
    96 
    98 
    99 template <typename OutputTile_, typename ThreadGemmShape_>
    100 struct IgemmConfig<OutputTile_, int8_t, ThreadGemmShape_>
    101  : public GemmConfig<
    103  int8_t,
    105  int8_t,
    107  int8_t,
    109  int8_t,
    111  OutputTile_,
    113  ThreadMultiplyAdd<ThreadGemmShape_, Shape<1, 4, 8>, int8_t, int8_t, int>,
    115  4,
    117  4,
    119  16,
    121  4,
    123  4,
    125  16,
    127  4,
    129  4,
    131  4,
    133  2,
    135  false,
    137  true,
    139  false> {};
    140 
    142 
    143 template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_, typename Index_>
    144 struct IgemmTileTraitsHelperA : public GemmTileTraitsHelperA<kLayout_, GemmConfig_> {};
    145 
    147 
    148 template <typename GemmConfig_, typename Index_>
    149 struct IgemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_, Index_>
    150  : public GemmTileTraitsHelperA<MatrixLayout::kColumnMajor, GemmConfig_> {
    153 
    155  static int const kScalarsPerStsA = 16;
    156 
    158  typedef IgemmGlobalTileTraits<
    160  // The layout.
    162  // The pointer is float const.
    163  int8_t const,
    164  // The tile has size KxM in GEMM's terminology.
    166  // The threads are distributed as warps x 32 (the traits may reorganize).
    168  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
    169  GemmConfig_::kScalarsPerLdgA>
    171 
    174 
    177  // The pointer is float.
    178  int8_t,
    179  // The tile has size KxM in GEMM's terminology.
    180  Shape<GemmConfig_::kStages, GemmConfig_::OutputTile::kD / 4, GemmConfig_::OutputTile::kW * 4>,
    181  // The threads are distributed as warps x 32 (the traits may reorganize).
    182  typename GlobalTileTraits::Threads,
    183  // The number of scalars per STS (STS.32 or STS.128, etc).
    184  kScalarsPerStsA>
    186 };
    187 
    189 
    190 template <typename GemmConfig_, typename Index_>
    191 struct IgemmTileTraitsHelperA<MatrixLayout::kRowMajor, GemmConfig_, Index_> {
    194 
    196  typedef int8_t Scalar;
    198  typedef int8_t MultiplyAddScalar;
    199 
    201  static int const kScalarsPerStsA = 16;
    202 
    204  typedef IgemmGlobalTileTraits<
    206  // The layout.
    208  // The pointer is float const.
    209  int8_t const,
    210  // The tile has size NxK in GEMM's terminology.
    212  // The threads are distributed as warps x 32 (the traits may reorganize).
    214  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
    215  GemmConfig_::kScalarsPerLdgA>
    217 
    220 
    223  // The pointer is int8.
    224  int8_t,
    225  // The tile has size KxN in GEMM's terminology.
    226  Shape<GemmConfig_::kStages, GemmConfig_::OutputTile::kD / 4, GemmConfig_::OutputTile::kW * 4>,
    227  // The threads are distributed as (threads / K) x K (the traits may reorganize).
    228  typename GlobalTileTraits::Threads,
    229  // The number of scalars per STS.
    230  kScalarsPerStsA,
    231  // The skew to avoid bank conflicts added in the tile W dimension.
    232  16>
    234 
    237  // The pointer is float const.
    238  int8_t const,
    239  // The output tile size.
    240  typename GemmConfig_::OutputTile,
    241  // The number of warps.
    242  typename GemmConfig_::Warps,
    243  // The number of threads per warp.
    244  typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
    245  // The shape of the FMA instruction.
    246  typename GemmConfig_::InstructionShape,
    247  // The number of stages.
    248  GemmConfig_::kStages,
    249  // The number of scalars per LDS.
    250  16,
    251  // The skew.
    252  SharedStoreTileTraits::kSkew>
    254 };
    255 
    257 
    258 template <enum MatrixLayout::Kind kLayout_, typename GemmConfig_, typename Index_>
    259 struct IgemmTileTraitsHelperB : public GemmTileTraitsHelperB<kLayout_, GemmConfig_> {};
    260 
    262 
    263 template <typename GemmConfig_, typename Index_>
    264 struct IgemmTileTraitsHelperB<MatrixLayout::kColumnMajor, GemmConfig_, Index_> {
    267 
    269  typedef int8_t Scalar;
    271  typedef int8_t MultiplyAddScalar;
    272 
    274  static int const kScalarsPerStsB = 16;
    275 
    277  typedef IgemmGlobalTileTraits<
    279  // The layout.
    281  // The pointer is float const.
    282  int8_t const,
    283  // The tile has size NxK in GEMM's terminology.
    285  // The threads are distributed as warps x 32 (the traits may reorganize).
    287  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
    288  GemmConfig_::kScalarsPerLdgB>
    290 
    293 
    296  // The pointer is int8.
    297  int8_t,
    298  // The tile has size KxN in GEMM's terminology.
    299  Shape<GemmConfig_::kStages, GemmConfig_::OutputTile::kD / 4, GemmConfig_::OutputTile::kH * 4>,
    300  // The threads are distributed as (threads / K) x K (the traits may reorganize).
    301  typename GlobalTileTraits::Threads,
    302  // The number of scalars per STS.
    303  kScalarsPerStsB,
    304  // The skew to avoid bank conflicts added in the tile W dimension.
    305  16>
    307 
    310  // The pointer is float const.
    311  int8_t const,
    312  // The output tile size.
    313  typename GemmConfig_::OutputTile,
    314  // The number of warps.
    315  typename GemmConfig_::Warps,
    316  // The number of threads per warp.
    317  typename GemmConfig_::MultiplyAdd::ThreadsPerWarp,
    318  // The shape of the FMA instruction.
    319  typename GemmConfig_::InstructionShape,
    320  // The number of stages.
    321  GemmConfig_::kStages,
    322  // The number of scalars per LDS.
    323  16,
    324  // The skew.
    325  SharedStoreTileTraits::kSkew>
    327 };
    328 
    330 
    331 template <typename GemmConfig_, typename Index_>
    332 struct IgemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_, Index_>
    333  : public GemmTileTraitsHelperB<MatrixLayout::kRowMajor, GemmConfig_> {
    336 
    338  static int const kScalarsPerStsB = 16;
    339 
    341  typedef IgemmGlobalTileTraits<
    343  // The layout.
    345  // The pointer is float const.
    346  int8_t const,
    347  // The tile has size KxM in GEMM's terminology.
    349  // The threads are distributed as warps x 32 (the traits may reorganize).
    351  // The number of scalars per LDG (LDG.32 or LDG.128, etc).
    352  GemmConfig_::kScalarsPerLdgB>
    354 
    357 
    360  // The pointer is float.
    361  int8_t,
    362  // The tile has size KxM in GEMM's terminology.
    363  Shape<GemmConfig_::kStages, GemmConfig_::OutputTile::kD / 4, GemmConfig_::OutputTile::kH * 4>,
    364  // The threads are distributed as warps x 32 (the traits may reorganize).
    365  typename GlobalTileTraits::Threads,
    366  // The number of scalars per STS (STS.32 or STS.128, etc).
    367  kScalarsPerStsB>
    369 };
    370 
    372 
    373 template <enum MatrixLayout::Kind kLayout_, typename Iterator_>
    375 
    376 template <typename Iterator_>
    377 struct IgemmTransformerA<MatrixLayout::kRowMajor, Iterator_> {
    379 };
    380 
    381 template <typename Iterator_>
    382 struct IgemmTransformerA<MatrixLayout::kColumnMajor, Iterator_> {
    384 };
    385 
    387 
    388 template <enum MatrixLayout::Kind kLayout_, typename Iterator_>
    390 
    391 template <typename Iterator_>
    392 struct IgemmTransformerB<MatrixLayout::kColumnMajor, Iterator_> {
    394 };
    395 
    396 template <typename Iterator_>
    397 struct IgemmTransformerB<MatrixLayout::kRowMajor, Iterator_> {
    399 };
    400 
    402 
    403 template <
    405  MatrixLayout::Kind kLayoutA_,
    407  MatrixLayout::Kind kLayoutB_,
    409  typename OutputTile_,
    411  typename ScalarD_,
    413  typename EpilogueFunctor_,
    415  typename ThreadGemmShape_ = Shape<32, 8, 8>,
    417  typename Index_ = int>
    425 
    427  typedef typename GemmTileTraitsHelperA::GlobalLoadIterator GlobalLoadIteratorA;
    429  typedef typename IgemmTransformerA<GemmTileTraitsHelperA::kLayout,
    432  typedef TileStoreIterator<typename GemmTileTraitsHelperA::SharedStoreTileTraits,
    433  typename GemmTileTraitsHelperA::SharedStoreTileTraits::Scalar,
    443 
    445  typedef typename GemmTileTraitsHelperB::GlobalLoadIterator GlobalLoadIteratorB;
    446  // The default transformer for B.
    447  typedef typename IgemmTransformerB<GemmTileTraitsHelperB::kLayout,
    450  typedef TileStoreIterator<typename GemmTileTraitsHelperB::SharedStoreTileTraits,
    451  typename GemmTileTraitsHelperB::SharedStoreTileTraits::Scalar,
    461 
    463  typedef TileLoadIterator<typename GemmTileTraitsHelperA::SharedLoadTileTraits,
    464  typename GemmTileTraitsHelperA::SharedLoadTileTraits::Scalar,
    472  typedef TileLoadIterator<typename GemmTileTraitsHelperB::SharedLoadTileTraits,
    473  typename GemmTileTraitsHelperB::SharedLoadTileTraits::Scalar,
    480 
    485 
    488 };
    489 
    491 
    492 template <typename ScalarD_>
    494  typedef float Scalar;
    495 };
    496 
    497 template <>
    498 struct IgemmEpilogueScalar<int> {
    499  typedef int Scalar;
    500 };
    501 
    503 
    504 template <
    506  MatrixLayout::Kind kLayoutA_,
    508  MatrixLayout::Kind kLayoutB_,
    510  typename OutputTile_ = Shape<32, 128, 128>,
    512  typename ScalarD_ = int,
    516  typename ThreadGemmShape_ = Shape<32, 8, 8>,
    518  typename Index_ = int,
    520  typename Helper_ = IgemmTraitsHelper<kLayoutA_,
    521  kLayoutB_,
    522  OutputTile_,
    523  ScalarD_,
    524  EpilogueFunctor_,
    525  ThreadGemmShape_,
    526  Index_> >
    527 struct IgemmTraits : public GemmTraits<
    528  // The config.
    529  typename Helper_::GemmConfig,
    530  // The stream to load A from global memory to shared memory.
    531  typename Helper_::GlobalLoadStreamA,
    532  // The stream to load B from global memory to shared memory.
    533  typename Helper_::GlobalLoadStreamB,
    534  // The stream to load A from shared memory.
    535  typename Helper_::SharedLoadStreamA,
    536  // The stream to load B from shared memory.
    537  typename Helper_::SharedLoadStreamB,
    538  // The epilogue.
    539  typename Helper_::Epilogue,
    540  // The block swizzle to reorganize the grid.
    541  IdentityBlockSwizzle,
    542  // The index.
    543  Index_,
    544  // The tool used to clear accumulators.
    545  typename Helper_::ClearAccumulators> {};
    546 
    548 
    549 } // namespace gemm
    550 } // namespace cutlass
    IgemmTransformerB< GemmTileTraitsHelperB::kLayout, GlobalLoadIteratorB >::Transformer GlobalTransformerB
    Definition: igemm_traits.h:448
    +
    Definition: load_store.h:41
    +
    GemmTileTraitsHelperB< MatrixLayout::kRowMajor, GemmConfig_ > Base
    The base config.
    Definition: igemm_traits.h:335
    Definition: convert.h:33
    -
    IgemmSwizzle< Iterator_ > Transformer
    Definition: igemm_traits.h:230
    +
    Definition: gemm_shared_tile.h:128
    +
    Base::Threads Threads
    The threads.
    Definition: igemm_global_tile.h:66
    +
    IgemmTileTraitsHelperB< kLayoutB_, GemmConfig, Index_ > GemmTileTraitsHelperB
    The GEMM config for B.
    Definition: igemm_traits.h:424
    + +
    IgemmSwizzle< Iterator_ > Transformer
    Definition: igemm_traits.h:383
    Defines iterators for efficiently loading and storing to global memory.
    -
    GemmGlobalIteratorAb< typename GemmTileTraitsHelperA::GlobalTileTraits, Index_ > GlobalLoadIteratorA
    The iterator to load A from global memory.
    Definition: igemm_traits.h:275
    Transposes a fragment of data containing packed 8-bit integer elements.
    -
    Copy< typename Iterator_::Fragment > Transformer
    Definition: igemm_traits.h:240
    +
    Copy< typename Iterator_::Fragment > Transformer
    Definition: igemm_traits.h:393
    +
    GemmSharedStoreWithSkewTileAbTraits< int8_t, Shape< GemmConfig_::kStages, GemmConfig_::OutputTile::kD/4, GemmConfig_::OutputTile::kW *4 >, typename GlobalTileTraits::Threads, kScalarsPerStsA, 16 > SharedStoreTileTraits
    The traits class to build the iterator to store data to shared memory for A^N.
    Definition: igemm_traits.h:233
    +
    IgemmGlobalTileTraits< GemmOperand::kB, MatrixLayout::kColumnMajor, int8_t const, Shape< 1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD >, Shape< 1, ShapeCount< typename GemmConfig_::Warps >::kCount, GemmConfig_::kWarpSize >, GemmConfig_::kScalarsPerLdgB > GlobalTileTraits
    The traits class to build the iterator to load data from global memory for B^T.
    Definition: igemm_traits.h:289
    Defines structural properties of complete GEMM computation.
    -
    GlobalLoadStream< GlobalLoadIteratorB, SharedStoreIteratorB, GlobalTransformerB > GlobalLoadStreamB
    The stream to load B from global memory to shared memory.
    Definition: igemm_traits.h:303
    -
    Definition: igemm_traits.h:133
    -
    TileStoreIterator< typename GemmTileTraitsHelperB::SharedStoreTileTraits, typename GemmTileTraitsHelperB::SharedStoreTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedStoreIteratorB
    The iterator to store B to shared memory.
    Definition: igemm_traits.h:300
    -
    IgemmTransformerB< GemmTileTraitsHelperB::kLayout, GlobalLoadIteratorB >::Transformer GlobalTransformerB
    Definition: igemm_traits.h:294
    +
    IgemmGlobalIteratorAb< GlobalTileTraits, Index_ > GlobalLoadIterator
    The global load iterator.
    Definition: igemm_traits.h:219
    +
    Definition: igemm_traits.h:144
    Definition: igemm_epilogue.h:290
    -
    IgemmContiguousGlobalTileTraits< GemmOperand::kB, MatrixLayout::kRowMajor, int8_t const, Shape< 1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kH >, Shape< 1, ShapeCount< typename GemmConfig_::Warps >::kCount, GemmConfig_::kWarpSize >, 4 > GlobalTileTraits
    The traits class to build the iterator to load data from global memory for B^T.
    Definition: igemm_traits.h:203
    Definition: convert.h:69
    -
    GemmTileTraitsHelperA< MatrixLayout::kColumnMajor, GemmConfig_ > Base
    The base config.
    Definition: igemm_traits.h:141
    -
    IgemmConfig< OutputTile_, ScalarD_, AccumulatorsPerThread_ > GemmConfig
    The IGEMM config.
    Definition: igemm_traits.h:267
    +
    IgemmGlobalTileTraits< GemmOperand::kB, MatrixLayout::kRowMajor, int8_t const, Shape< 1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kH >, Shape< 1, ShapeCount< typename GemmConfig_::Warps >::kCount, GemmConfig_::kWarpSize >, GemmConfig_::kScalarsPerLdgB > GlobalTileTraits
    The traits class to build the iterator to load data from global memory for B^T.
    Definition: igemm_traits.h:353
    Definition: gemm_shared_tile.h:38
    -
    Definition: tile_iterator.h:62
    +
    Definition: tile_iterator.h:65
    +
    int8_t MultiplyAddScalar
    The scalar stored in shared memory.
    Definition: igemm_traits.h:198
    +
    GemmTileTraitsHelperB::GlobalLoadIterator GlobalLoadIteratorB
    The iterator to load B from global memory.
    Definition: igemm_traits.h:445
    Implements matrix multiply accumulate operation of 8-bit integer data using DP4A instruction.
    -
    Definition: gemm_global_tile.h:159
    -
    GemmSharedStoreTileAbTraits< int8_t, Shape< GemmConfig_::kStages, GemmConfig_::OutputTile::kD/4, GemmConfig_::OutputTile::kH *4 >, typename GlobalTileTraits::Threads, kScalarsPerStsB > SharedStoreTileTraits
    The traits class to build the iterator to store data to shared memory for B^N.
    Definition: igemm_traits.h:215
    +
    Definition: gemm_shared_tile.h:200
    +
    TileStoreIterator< typename GemmTileTraitsHelperB::SharedStoreTileTraits, typename GemmTileTraitsHelperB::SharedStoreTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedStoreIteratorB
    The iterator to store B to shared memory.
    Definition: igemm_traits.h:454
    +
    GemmSharedLoadTileBTraits< int8_t const, typename GemmConfig_::OutputTile, typename GemmConfig_::Warps, typename GemmConfig_::MultiplyAdd::ThreadsPerWarp, typename GemmConfig_::InstructionShape, GemmConfig_::kStages, 16, SharedStoreTileTraits::kSkew > SharedLoadTileTraits
    The traits class to build the iterator to load from shared memory for B^N.
    Definition: igemm_traits.h:326
    +
    Definition: gemm_global_tile.h:163
    +
    int8_t MultiplyAddScalar
    The scalar stored in shared memory.
    Definition: igemm_traits.h:271
    Implements the epilogue phase of the GEMM kernel that efficiently updates global memory with the comp...
    -
    Definition: gemm_global_stream.h:161
    -
    Definition: gemm_traits.h:273
    -
    GemmGlobalIteratorAb< typename GemmTileTraitsHelperB::GlobalTileTraits, Index_ > GlobalLoadIteratorB
    The iterator to load B from global memory.
    Definition: igemm_traits.h:291
    -
    IgemmContiguousGlobalTileTraits< GemmOperand::kA, MatrixLayout::kColumnMajor, int8_t const, Shape< 1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kW >, Shape< 1, ShapeCount< typename GemmConfig_::Warps >::kCount, GemmConfig_::kWarpSize >, 4 > GlobalTileTraits
    The traits class to build the iterator to load data from global memory for A^N.
    Definition: igemm_traits.h:159
    -
    int Scalar
    Definition: igemm_traits.h:342
    -
    IgemmSwizzle< Iterator_ > Transformer
    Definition: igemm_traits.h:245
    -
    Describes layouts of matrices.
    Definition: matrix_traits.h:35
    -
    IgemmTileTraitsHelperB< kLayoutB_, GemmConfig > GemmTileTraitsHelperB
    The GEMM config for B.
    Definition: igemm_traits.h:271
    +
    IgemmGlobalTileTraits< GemmOperand::kA, MatrixLayout::kRowMajor, int8_t const, Shape< 1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD >, Shape< 1, ShapeCount< typename GemmConfig_::Warps >::kCount, GemmConfig_::kWarpSize >, GemmConfig_::kScalarsPerLdgA > GlobalTileTraits
    The traits class to build the iterator to load data from global memory for A^T.
    Definition: igemm_traits.h:216
    +
    Definition: gemm_global_stream.h:52
    +
    Definition: gemm_traits.h:191
    +
    IgemmEpilogue< IgemmEpilogueTraits< GemmConfig, EpilogueFunctor_ > > Epilogue
    The epilogue.
    Definition: igemm_traits.h:487
    +
    int Scalar
    Definition: igemm_traits.h:499
    +
    IgemmSwizzle< Iterator_ > Transformer
    Definition: igemm_traits.h:398
    +
    Defines data layouts of various matrix formats usable by TensorRef and other classes.
    Definition: matrix_traits.h:156
    +
    GemmSharedStoreTileAbTraits< int8_t, Shape< GemmConfig_::kStages, GemmConfig_::OutputTile::kD/4, GemmConfig_::OutputTile::kW *4 >, typename GlobalTileTraits::Threads, kScalarsPerStsA > SharedStoreTileTraits
    The traits class to build the iterator to store data to shared memory for A^N.
    Definition: igemm_traits.h:185
    Definition: igemm_swizzle.h:38
    -
    Definition: igemm_traits.h:177
    -
    Definition: igemm_traits.h:265
    -
    An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
    Definition: tile_iterator.h:302
    -
    GlobalLoadStream< GlobalLoadIteratorA, SharedStoreIteratorA, GlobalTransformerA > GlobalLoadStreamA
    The stream to load A from global memory to shared memory.
    Definition: igemm_traits.h:287
    -
    SharedLoadStream< SharedLoadIteratorB, Copy< typename SharedLoadIteratorB::Fragment > > SharedLoadStreamB
    The stream to load B from shared memory.
    Definition: igemm_traits.h:322
    +
    Definition: igemm_traits.h:259
    +
    Definition: igemm_traits.h:418
    +
    An iterator implementing Tile Load Iterator Concept for loading a tile from memory.
    Definition: tile_iterator.h:399
    +
    IgemmTransformerA< GemmTileTraitsHelperA::kLayout, GlobalLoadIteratorA >::Transformer GlobalTransformerA
    The default transformer for A.
    Definition: igemm_traits.h:430
    Defines iterators for efficiently loading and storing tiles to and from shared memory.
    -
    Definition: matrix_traits.h:36
    -
    IgemmTileTraitsHelperA< kLayoutA_, GemmConfig > GemmTileTraitsHelperA
    The GEMM config for A.
    Definition: igemm_traits.h:269
    -
    Definition: gemm_shared_stream.h:44
    +
    GlobalLoadStream< GemmOperand::kB, GlobalLoadIteratorB, SharedStoreIteratorB, GlobalTransformerB > GlobalLoadStreamB
    The stream to load B from global memory to shared memory.
    Definition: igemm_traits.h:460
    +
    Definition: matrix_traits.h:159
    +
    Definition: gemm_shared_stream.h:45
    +
    Definition: igemm_global_tile.h:50
    Defines a type for restructuring a tile.
    -
    TileLoadIterator< typename GemmTileTraitsHelperA::SharedLoadTileTraits, typename GemmTileTraitsHelperA::SharedLoadTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedLoadIteratorA
    The iterator to load A from shared memory.
    Definition: igemm_traits.h:310
    -
    ClearAccumulators< typename MultiplyAdd::ScalarC > ClearAccumulators
    The object to clear accumulators.
    Definition: igemm_traits.h:327
    -
    Definition: gemm_traits.h:79
    -
    Definition: gemm_traits.h:137
    -
    Definition: matrix_traits.h:43
    +
    GemmTileTraitsHelperA::GlobalLoadIterator GlobalLoadIteratorA
    The iterator to load A from global memory.
    Definition: igemm_traits.h:427
    +
    Definition: gemm_config.h:76
    +
    TileStoreIterator< typename GemmTileTraitsHelperA::SharedStoreTileTraits, typename GemmTileTraitsHelperA::SharedStoreTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedStoreIteratorA
    The iterator to store A to shared memory.
    Definition: igemm_traits.h:436
    +
    Definition: gemm_traits.h:52
    +
    Definition: matrix_traits.h:357
    Definition: igemm_traits.h:57
    -
    Definition: igemm_traits.h:221
    -
    Definition: igemm_global_tile.h:50
    -
    float Scalar
    Definition: igemm_traits.h:337
    -
    Definition: gemm_traits.h:428
    -
    Copy< typename Iterator_::Fragment > Transformer
    Definition: igemm_traits.h:225
    -
    Definition: igemm_traits.h:370
    +
    IgemmGlobalTileTraits< GemmOperand::kA, MatrixLayout::kColumnMajor, int8_t const, Shape< 1, GemmConfig_::OutputTile::kD, GemmConfig_::OutputTile::kW >, Shape< 1, ShapeCount< typename GemmConfig_::Warps >::kCount, GemmConfig_::kWarpSize >, GemmConfig_::kScalarsPerLdgA > GlobalTileTraits
    The traits class to build the iterator to load data from global memory for A^N.
    Definition: igemm_traits.h:170
    +
    Definition: igemm_global_tile.h:95
    +
    Definition: igemm_traits.h:374
    +
    float Scalar
    Definition: igemm_traits.h:494
    +
    Definition: gemm_traits.h:349
    +
    Copy< typename Iterator_::Fragment > Transformer
    Definition: igemm_traits.h:378
    +
    Definition: igemm_traits.h:527
    A Shape implementing Layout Concept describing the dimensions of a cube.
    Definition: shape.h:64
    -
    GemmSharedStoreTileAbTraits< int8_t, Shape< GemmConfig_::kStages, GemmConfig_::OutputTile::kD/4, GemmConfig_::OutputTile::kW *4 >, typename GlobalTileTraits::Threads, kScalarsPerStsA > SharedStoreTileTraits
    The traits class to build the iterator to store data to shared memory for A^N.
    Definition: igemm_traits.h:171
    - -
    Template performing matrix multiply-add operation within a thread.
    Definition: thread_multiply_add.h:43
    -
    Definition: matrix_traits.h:36
    - -
    IgemmEpilogue< IgemmEpilogueTraits< GemmConfig, EpilogueFunctor_ > > Epilogue
    The epilogue.
    Definition: igemm_traits.h:330
    -
    IgemmTransformerA< GemmTileTraitsHelperA::kLayout, GlobalLoadIteratorA >::Transformer GlobalTransformerA
    The default transformer for A.
    Definition: igemm_traits.h:278
    -
    Kind
    Definition: matrix_traits.h:36
    -
    Definition: igemm_traits.h:236
    -
    TileStoreIterator< typename GemmTileTraitsHelperA::SharedStoreTileTraits, typename GemmTileTraitsHelperA::SharedStoreTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedStoreIteratorA
    The iterator to store A to shared memory.
    Definition: igemm_traits.h:284
    -
    Functor to compute linear combination of fragments.
    Definition: linear_scaling.h:40
    -
    Definition: matrix_traits.h:43
    +
    TileLoadIterator< typename GemmTileTraitsHelperB::SharedLoadTileTraits, typename GemmTileTraitsHelperB::SharedLoadTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedLoadIteratorB
    The iterator to load B from shared memory.
    Definition: igemm_traits.h:476
    +
    GemmSharedStoreTileAbTraits< int8_t, Shape< GemmConfig_::kStages, GemmConfig_::OutputTile::kD/4, GemmConfig_::OutputTile::kH *4 >, typename GlobalTileTraits::Threads, kScalarsPerStsB > SharedStoreTileTraits
    The traits class to build the iterator to store data to shared memory for B^N.
    Definition: igemm_traits.h:368
    +
    ReshapeThreads< VectorizedTile, Threads_ >::Threads Threads
    The threads shape.
    Definition: gemm_global_tile.h:88
    + +
    Template performing matrix multiply-add operation within a thread.
    Definition: thread_multiply_add.h:44
    +
    Definition: matrix_traits.h:159
    + + +
    IgemmConfig< OutputTile_, ScalarD_, ThreadGemmShape_ > GemmConfig
    The IGEMM config.
    Definition: igemm_traits.h:420
    +
    IgemmGlobalIteratorAb< GlobalTileTraits, Index_ > GlobalLoadIterator
    The global load iterator.
    Definition: igemm_traits.h:292
    +
    Kind
    Enumeration defining fundamental contiguous layouts.
    Definition: matrix_traits.h:159
    +
    GemmGlobalIteratorAb< GlobalTileTraits, Index_ > GlobalLoadIterator
    The global load iterator.
    Definition: igemm_traits.h:173
    +
    GemmGlobalIteratorAb< GlobalTileTraits, Index_ > GlobalLoadIterator
    The global load iterator.
    Definition: igemm_traits.h:356
    +
    GemmConfig::MultiplyAdd MultiplyAdd
    The multiply-add functor.
    Definition: igemm_traits.h:482
    +
    Definition: igemm_traits.h:389
    +
    Functor to compute linear combination of fragments.
    Definition: linear_scaling.h:51
    +
    SharedLoadStream< SharedLoadIteratorA, Copy< typename SharedLoadIteratorA::Fragment > > SharedLoadStreamA
    The stream to load A from shared memory.
    Definition: igemm_traits.h:470
    +
    Definition: matrix_traits.h:357
    +
    GlobalLoadStream< GemmOperand::kA, GlobalLoadIteratorA, SharedStoreIteratorA, GlobalTransformerA > GlobalLoadStreamA
    The stream to load A from global memory to shared memory.
    Definition: igemm_traits.h:442
    +
    IgemmTileTraitsHelperA< kLayoutA_, GemmConfig, Index_ > GemmTileTraitsHelperA
    The GEMM config for A.
    Definition: igemm_traits.h:422
    Implements a software-pipelined efficient GEMM.
    -
    ReshapeThreads< Tile, Threads_ >::Threads Threads
    The threads shape.
    Definition: gemm_global_tile.h:87
    +
    GemmSharedLoadTileATraits< int8_t const, typename GemmConfig_::OutputTile, typename GemmConfig_::Warps, typename GemmConfig_::MultiplyAdd::ThreadsPerWarp, typename GemmConfig_::InstructionShape, GemmConfig_::kStages, 16, SharedStoreTileTraits::kSkew > SharedLoadTileTraits
    The traits class to build the iterator to load from shared memory for A^N.
    Definition: igemm_traits.h:253
    +
    SharedLoadStream< SharedLoadIteratorB, Copy< typename SharedLoadIteratorB::Fragment > > SharedLoadStreamB
    The stream to load B from shared memory.
    Definition: igemm_traits.h:479
    Defines structural properties of the GEMM epilogue.
    -
    Definition: igemm_traits.h:336
    +
    Definition: igemm_traits.h:493
    Defines the epilogue phase of the GEMM computation for IGEMM, supporting integer and floating-point o...
    Defines conversion operations among Fragments of different base type.
    -
    GemmTileTraitsHelperB< MatrixLayout::kRowMajor, GemmConfig_ > Base
    The base config.
    Definition: igemm_traits.h:185
    -
    SharedLoadStream< SharedLoadIteratorA, Copy< typename SharedLoadIteratorA::Fragment > > SharedLoadStreamA
    The stream to load A from shared memory.
    Definition: igemm_traits.h:313
    +
    GemmSharedStoreWithSkewTileAbTraits< int8_t, Shape< GemmConfig_::kStages, GemmConfig_::OutputTile::kD/4, GemmConfig_::OutputTile::kH *4 >, typename GlobalTileTraits::Threads, kScalarsPerStsB, 16 > SharedStoreTileTraits
    The traits class to build the iterator to store data to shared memory for B^N.
    Definition: igemm_traits.h:306
    +
    Implements tile iterators to partition the thread block tile into 2D subtiles and efficiently load ea...
    -
    An iterator implementing Tile Store Iterator Concept for storing a tile to memory.
    Definition: tile_iterator.h:620
    -
    GemmConfig::MultiplyAdd MultiplyAdd
    The multiply-add functor.
    Definition: igemm_traits.h:325
    +
    TileLoadIterator< typename GemmTileTraitsHelperA::SharedLoadTileTraits, typename GemmTileTraitsHelperA::SharedLoadTileTraits::Scalar, IteratorAdvance::kH, MemorySpace::kShared > SharedLoadIteratorA
    The iterator to load A from shared memory.
    Definition: igemm_traits.h:467
    +
    GemmTileTraitsHelperA< MatrixLayout::kColumnMajor, GemmConfig_ > Base
    The base config.
    Definition: igemm_traits.h:152
    +
    An iterator implementing Tile Store Iterator Concept for storing a tile to memory.
    Definition: tile_iterator.h:836
    +
    ClearAccumulators< typename MultiplyAdd::ScalarC > ClearAccumulators
    The object to clear accumulators.
    Definition: igemm_traits.h:484
    diff --git a/docs/index.html b/docs/index.html index f2ba68993..6fab15e10 100644 --- a/docs/index.html +++ b/docs/index.html @@ -75,7 +75,7 @@ $(function() {
    diff --git a/docs/iterator__access_8h.html b/docs/iterator__access_8h.html index cc41cd5af..06fd90ad6 100644 --- a/docs/iterator__access_8h.html +++ b/docs/iterator__access_8h.html @@ -82,10 +82,9 @@ $(function() {

    Free functions for loading and storing to implementations of tile iterator concepts. More...

    -

    @@ -142,7 +146,7 @@ Namespaces

    @@ -98,76 +97,15 @@ Namespaces Functions - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    template<typename InputIterator , typename Fragment >
    CUTLASS_HOST_DEVICE void cutlass::iterator_load (InputIterator &iterator, Fragment &fragment)
     Loads a fragment from an input iterator. More...
     
    template<typename InputIterator , typename Fragment >
    CUTLASS_DEVICE void cutlass::shared_iterator_load (InputIterator &iterator, Fragment &fragment)
     Loads a fragment from a shared memory input iterator. More...
     
    template<typename InputIterator , typename Fragment >
    CUTLASS_DEVICE void cutlass::shared_iterator_load (InputIterator &iterator, Fragment &fragment, int d)
     Loads a fragment from a shared memory input iterator. More...
     
    template<typename InputIterator , typename Fragment , typename ConstPredicateAdapter >
    CUTLASS_HOST_DEVICE void cutlass::iterator_load_post_increment (InputIterator &iterator, Fragment &fragment, typename InputIterator::Index offset, ConstPredicateAdapter predicate_adapter)
     Loads a fragment from an input iterator, masked by a predicate iterator. More...
     
    template<typename InputIterator , typename Fragment >
    CUTLASS_HOST_DEVICE void cutlass::iterator_load_post_increment (InputIterator &iterator, Fragment &fragment, typename InputIterator::Index offset=0)
     Loads a fragment from an input iterator. More...
     
    template<typename InputIterator , typename Fragment , typename ConstPredicateAdapter >
    CUTLASS_HOST_DEVICE void cutlass::iterator_load_post_increment (InputIterator &iterator, Fragment &fragment, ConstPredicateAdapter pred_it)
     Loads a fragment from an input iterator. More...
     
    template<typename InputIterator , typename Fragment , typename ConstPredicateAdapter >
    CUTLASS_HOST_DEVICE void cutlass::iterator_load (InputIterator const &_iterator, Fragment &fragment, typename InputIterator::Index offset, ConstPredicateAdapter predicate_adapter)
     
    template<typename InputIterator , typename Fragment >
    CUTLASS_HOST_DEVICE void cutlass::iterator_load (InputIterator const &iterator, Fragment &fragment, typename InputIterator::Index offset=0)
     Loads a fragment from an input iterator. More...
     
    template<typename InputIterator , typename Fragment , typename ConstPredicateAdapter >
    CUTLASS_HOST_DEVICE void cutlass::iterator_load (InputIterator const &iterator, Fragment &fragment, ConstPredicateAdapter pred_it)
     Loads a fragment from an input iterator. More...
     
    template<typename OutputIterator , typename Fragment >
    CUTLASS_HOST_DEVICE void cutlass::iterator_store (OutputIterator &iterator, Fragment &fragment)
     Stores a fragment to an output iterator. More...
     
    template<typename OutputIterator , typename Fragment >
    CUTLASS_DEVICE void cutlass::shared_iterator_store (OutputIterator &iterator, Fragment const &fragment)
     Stores a fragment to a shared memory output iterator. More...
     
    template<typename OutputIterator , typename Fragment , typename ConstPredicateAdapter >
    CUTLASS_HOST_DEVICE void cutlass::iterator_store_post_increment (OutputIterator &iterator, Fragment const &fragment, typename OutputIterator::Index offset, ConstPredicateAdapter predicate_adapter)
     Stores a fragment to an output iterator, masked by a predicate iterator. More...
     
    template<typename OutputIterator , typename Fragment >
    CUTLASS_HOST_DEVICE void cutlass::iterator_store_post_increment (OutputIterator &iterator, Fragment const &fragment, typename OutputIterator::Index offset=0)
     Stores a fragment to an output iterator. More...
     
    template<typename OutputIterator , typename Fragment , typename ConstPredicateAdapter >
    CUTLASS_HOST_DEVICE void cutlass::iterator_store_post_increment (OutputIterator &iterator, Fragment const &fragment, ConstPredicateAdapter pred_it)
     Stores a fragment to an output iterator. More...
     
    template<typename OutputIterator , typename Fragment , typename ConstPredicateAdapter >
    CUTLASS_HOST_DEVICE void cutlass::iterator_store (OutputIterator const &_iterator, Fragment const &fragment, typename OutputIterator::Index offset, ConstPredicateAdapter predicate_adapter)
     Stores a fragment to an output iterator, masked by a predicate iterator. More...
     
    template<typename OutputIterator , typename Fragment >
    CUTLASS_HOST_DEVICE void cutlass::iterator_store (OutputIterator const &iterator, Fragment const &fragment, typename OutputIterator::Index offset=0)
     Stores a fragment to an output iterator. More...
     
    template<typename OutputIterator , typename Fragment , typename ConstPredicateAdapter >
    CUTLASS_HOST_DEVICE void cutlass::iterator_store (OutputIterator const &iterator, Fragment const &fragment, ConstPredicateAdapter pred_it)
     Stores a fragment to an output iterator. More...
     
    diff --git a/docs/iterator__access_8h_source.html b/docs/iterator__access_8h_source.html index 11289a933..fac9ea1e5 100644 --- a/docs/iterator__access_8h_source.html +++ b/docs/iterator__access_8h_source.html @@ -76,30 +76,18 @@ $(function() {
    iterator_access.h
    -Go to the documentation of this file.
    1 /***************************************************************************************************
    2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
    3  *
    4  * Redistribution and use in source and binary forms, with or without modification, are permitted
    5  * provided that the following conditions are met:
    6  * * Redistributions of source code must retain the above copyright notice, this list of
    7  * conditions and the following disclaimer.
    8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
    9  * conditions and the following disclaimer in the documentation and/or other materials
    10  * provided with the distribution.
    11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
    12  * to endorse or promote products derived from this software without specific prior written
    13  * permission.
    14  *
    15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
    16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
    17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
    18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
    19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
    20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
    21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    23  *
    24  **************************************************************************************************/
    28 #pragma once
    29 
    31 #include <cutlass/load_store.h>
    33 #include <cutlass/shape.h>
    34 
    35 namespace cutlass {
    36 
    38 
    40 template <typename InputIterator, typename Fragment>
    41 CUTLASS_HOST_DEVICE void iterator_load(InputIterator &iterator, Fragment &fragment) {
    42  typename InputIterator::FragmentIterator frag_iterator(fragment);
    43  for (int d = 0; d < InputIterator::Iterations::kD; ++d) {
    44  for (int h = 0; h < InputIterator::Iterations::kH; ++h) {
    45  for (int w = 0; w < InputIterator::Iterations::kW; ++w) {
    46  for (int c = 0; c < InputIterator::Iterations::kC; ++c) {
    47  if (iterator.valid(d, h, w, c)) {
    48  int const offset =
    50  0, 0, w, c);
    52  load(reinterpret_cast<typename InputIterator::AccessType &>(
    53  frag_iterator.at(d, h, w, c)),
    54  iterator.data(),
    55  offset);
    56  }
    57  }
    58  if (w < InputIterator::Iterations::kW - 1) {
    59  iterator.inc_w();
    60  }
    61  }
    62  if (h < InputIterator::Iterations::kH - 1) {
    63  iterator.inc_h();
    64  }
    65  }
    66  if (d < InputIterator::Iterations::kD - 1) {
    67  iterator.inc_d();
    68  }
    69  }
    70  iterator.inc_advance();
    71 }
    72 
    74 template <typename InputIterator, typename Fragment>
    75 CUTLASS_DEVICE void shared_iterator_load(InputIterator &iterator, Fragment &fragment) {
    76  typename InputIterator::FragmentIterator frag_iterator(fragment);
    77  for (int d = 0; d < InputIterator::Iterations::kD; ++d) {
    78  for (int h = 0; h < InputIterator::Iterations::kH; ++h) {
    79  for (int w = 0; w < InputIterator::Iterations::kW; ++w) {
    80  for (int c = 0; c < InputIterator::Iterations::kC; ++c) {
    81  int const offset =
    83  d, h, w, c);
    84 
    85  FragmentLoad<InputIterator::kIteratorFragment,
    86  InputIterator::Tile::kC,
    87  typename InputIterator::Scalar,
    88  InputIterator::kMemorySpace,
    89  typename InputIterator::FragmentElement,
    90  InputIterator::Tile::kW>::load(frag_iterator.at(d, h, w, c),
    91  iterator.data(),
    92  offset);
    93  }
    94  }
    95  }
    96  }
    97 }
    98 
    100 template <typename InputIterator, typename Fragment>
    101 CUTLASS_DEVICE void shared_iterator_load(InputIterator &iterator, Fragment &fragment, int d) {
    102  typename InputIterator::FragmentIterator frag_iterator(fragment);
    103  for (int h = 0; h < InputIterator::Iterations::kH; ++h) {
    104  for (int w = 0; w < InputIterator::Iterations::kW; ++w) {
    105  for (int c = 0; c < InputIterator::Iterations::kC; ++c) {
    106  int const offset =
    108  d, h, w, c);
    109 
    110  FragmentLoad<InputIterator::kIteratorFragment,
    111  InputIterator::Tile::kC,
    112  typename InputIterator::Scalar,
    113  InputIterator::kMemorySpace,
    114  typename InputIterator::FragmentElement,
    115  InputIterator::Tile::kW>::load(frag_iterator.at(0, h, w, c),
    116  iterator.data(),
    117  offset);
    118  }
    119  }
    120  }
    121 }
    122 
    124 template <typename InputIterator, typename Fragment, typename ConstPredicateAdapter>
    126  Fragment &fragment,
    127  typename InputIterator::Index offset,
    128  ConstPredicateAdapter predicate_adapter) {
    129  for (int d = 0; d < InputIterator::Iterations::kD; ++d, iterator.inc_d()) {
    130  for (int h = 0; h < InputIterator::Iterations::kH; ++h, iterator.inc_h()) {
    131  for (int w = 0; w < InputIterator::Iterations::kW; ++w, iterator.inc_w()) {
    132  if (predicate_adapter.at(d, h, w, 0)) {
    133  int idx = InputIterator::Tile::kC *
    134  (w + InputIterator::Iterations::kW * (h + InputIterator::Iterations::kH * d));
    135 
    137  load(reinterpret_cast<typename InputIterator::AccessType &>(fragment[idx]),
    138  iterator.data(),
    139  offset);
    140  }
    141  }
    142  }
    143  }
    144 }
    145 
    147 template <typename InputIterator, typename Fragment>
    149  Fragment &fragment,
    150  typename InputIterator::Index offset = 0) {
    152  iterator_load_post_increment(iterator, fragment, offset, pred);
    153 }
    154 
    156 template <typename InputIterator, typename Fragment, typename ConstPredicateAdapter>
    158  Fragment &fragment,
    159  ConstPredicateAdapter pred_it) {
    160  iterator_load_post_increment(iterator, fragment, 0, pred_it);
    161 }
    162 
    163 template <typename InputIterator, typename Fragment, typename ConstPredicateAdapter>
    164 CUTLASS_HOST_DEVICE void iterator_load(InputIterator const &_iterator,
    165  Fragment &fragment,
    166  typename InputIterator::Index offset,
    167  ConstPredicateAdapter predicate_adapter) {
    168  InputIterator iterator(_iterator);
    169  iterator_load_post_increment(iterator, fragment, offset, predicate_adapter);
    170 }
    171 
    173 template <typename InputIterator, typename Fragment>
    174 CUTLASS_HOST_DEVICE void iterator_load(InputIterator const &iterator,
    175  Fragment &fragment,
    176  typename InputIterator::Index offset = 0) {
    178  iterator_load(iterator, fragment, offset, pred);
    179 }
    180 
    182 template <typename InputIterator, typename Fragment, typename ConstPredicateAdapter>
    183 CUTLASS_HOST_DEVICE void iterator_load(InputIterator const &iterator,
    184  Fragment &fragment,
    185  ConstPredicateAdapter pred_it) {
    186  iterator_load(iterator, fragment, 0, pred_it);
    187 }
    188 
    190 
    192 template <typename OutputIterator, typename Fragment>
    193 CUTLASS_HOST_DEVICE void iterator_store(OutputIterator &iterator, Fragment &fragment) {
    194  typename OutputIterator::FragmentIterator frag_iterator(fragment);
    195  for (int d = 0; d < OutputIterator::Iterations::kD; ++d) {
    196  for (int h = 0; h < OutputIterator::Iterations::kH; ++h) {
    197  for (int w = 0; w < OutputIterator::Iterations::kW; ++w) {
    198  if (iterator.valid(d, h, w, 0)) {
    199  int const offset =
    201  d, h, w, 0);
    202 
    203  Store<typename Fragment::Element,
    204  OutputIterator::Tile::kC,
    205  OutputIterator::kMemorySpace>::
    206  store(reinterpret_cast<typename OutputIterator::AccessType &>(
    207  frag_iterator.at(d, h, w, 0)),
    208  iterator.data(),
    209  offset);
    210  }
    211  if (w < OutputIterator::Iterations::kW - 1) {
    212  iterator.inc_w();
    213  }
    214  }
    215  if (h < OutputIterator::Iterations::kH - 1) {
    216  iterator.inc_h();
    217  }
    218  }
    219  if (d < OutputIterator::Iterations::kD - 1) {
    220  iterator.inc_d();
    221  }
    222  }
    223  iterator.inc_advance();
    224 }
    225 
    227 template <typename OutputIterator, typename Fragment>
    228 CUTLASS_DEVICE void shared_iterator_store(OutputIterator &iterator, Fragment const &fragment) {
    229  typename OutputIterator::FragmentConstIterator frag_iterator(fragment);
    230  for (int d = 0; d < OutputIterator::Iterations::kD; ++d) {
    231  for (int h = 0; h < OutputIterator::Iterations::kH; ++h) {
    232  for (int w = 0; w < OutputIterator::Iterations::kW; ++w) {
    233  for (int c = 0; c < OutputIterator::Iterations::kC; ++c) {
    234  int const offset =
    236  d, h, w, c);
    237 
    238  FragmentStore<OutputIterator::kIteratorFragment,
    239  OutputIterator::Tile::kC,
    240  typename OutputIterator::Scalar,
    241  OutputIterator::kMemorySpace,
    242  typename OutputIterator::FragmentElement,
    243  OutputIterator::Tile::kW>::store(frag_iterator.at(d, h, w, c),
    244  iterator.data(),
    245  offset);
    246  }
    247  }
    248  }
    249  }
    250 }
    251 
    253 
    255 template <typename OutputIterator, typename Fragment, typename ConstPredicateAdapter>
    257  Fragment const &fragment,
    258  typename OutputIterator::Index offset,
    259  ConstPredicateAdapter predicate_adapter) {
    260  for (int d = 0; d < OutputIterator::Iterations::kD; ++d, iterator.inc_d()) {
    261  for (int h = 0; h < OutputIterator::Iterations::kH; ++h, iterator.inc_h()) {
    262  for (int w = 0; w < OutputIterator::Iterations::kW; ++w, iterator.inc_w()) {
    263  if (predicate_adapter.at(d, h, w, 0)) {
    264  int idx = OutputIterator::Tile::kC *
    265  (w + OutputIterator::Iterations::kW * (h + OutputIterator::Iterations::kH * d));
    266 
    267  Store<typename Fragment::Element,
    268  OutputIterator::Tile::kC,
    269  OutputIterator::kMemorySpace>::
    270  store(reinterpret_cast<typename OutputIterator::AccessType const &>(fragment[idx]),
    271  iterator.data(),
    272  offset);
    273  }
    274  }
    275  }
    276  }
    277 }
    278 
    280 template <typename OutputIterator, typename Fragment>
    282  Fragment const &fragment,
    283  typename OutputIterator::Index offset = 0) {
    285  iterator_store_post_increment(iterator, fragment, offset, pred);
    286 }
    287 
    289 template <typename OutputIterator, typename Fragment, typename ConstPredicateAdapter>
    291  Fragment const &fragment,
    292  ConstPredicateAdapter pred_it) {
    293  iterator_store_post_increment(iterator, fragment, 0, pred_it);
    294 }
    295 
    297 template <typename OutputIterator, typename Fragment, typename ConstPredicateAdapter>
    298 CUTLASS_HOST_DEVICE void iterator_store(OutputIterator const &_iterator,
    299  Fragment const &fragment,
    300  typename OutputIterator::Index offset,
    301  ConstPredicateAdapter predicate_adapter) {
    302  OutputIterator iterator(_iterator);
    303  iterator_store_post_increment(iterator, fragment, offset, predicate_adapter);
    304 }
    305 
    307 template <typename OutputIterator, typename Fragment>
    308 CUTLASS_HOST_DEVICE void iterator_store(OutputIterator const &iterator,
    309  Fragment const &fragment,
    310  typename OutputIterator::Index offset = 0) {
    312  iterator_store(iterator, fragment, offset, pred);
    313 }
    314 
    316 template <typename OutputIterator, typename Fragment, typename ConstPredicateAdapter>
    317 CUTLASS_HOST_DEVICE void iterator_store(OutputIterator const &iterator,
    318  Fragment const &fragment,
    319  ConstPredicateAdapter pred_it) {
    320  iterator_store(iterator, fragment, 0, pred_it);
    321 }
    322 
    324 
    325 } // namespace cutlass
    Definition: fragment_load_store.h:43
    -
    Definition: convert.h:33
    -
    CUTLASS_DEVICE void shared_iterator_load(InputIterator &iterator, Fragment &fragment)
    Loads a fragment from a shared memory input iterator.
    Definition: iterator_access.h:75
    -
    CUTLASS_HOST_DEVICE void iterator_store_post_increment(OutputIterator &iterator, Fragment const &fragment, typename OutputIterator::Index offset, ConstPredicateAdapter predicate_adapter)
    Stores a fragment to an output iterator, masked by a predicate iterator.
    Definition: iterator_access.h:256
    -
    Defines accessors for loading and storing fragments to memory efficiently.
    -
    static CUTLASS_DEVICE void load(AccessType &dst, Scalar_ const *pointer, int offset)
    The load function.
    Definition: load_store.h:59
    +Go to the documentation of this file.
    1 /***************************************************************************************************
    2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
    3  *
    4  * Redistribution and use in source and binary forms, with or without modification, are permitted
    5  * provided that the following conditions are met:
    6  * * Redistributions of source code must retain the above copyright notice, this list of
    7  * conditions and the following disclaimer.
    8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
    9  * conditions and the following disclaimer in the documentation and/or other materials
    10  * provided with the distribution.
    11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
    12  * to endorse or promote products derived from this software without specific prior written
    13  * permission.
    14  *
    15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
    16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
    17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
    18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
    19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
    20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
    21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    23  *
    24  **************************************************************************************************/
    28 #pragma once
    29 
    30 #include "cutlass/load_store.h"
    32 #include "cutlass/shape.h"
    33 
    34 namespace cutlass {
    35 
    37 // Used by convolution
    38 template <typename InputIterator, typename Fragment>
    39 CUTLASS_HOST_DEVICE void iterator_load(InputIterator &iterator, Fragment &fragment) {
    40  typename InputIterator::FragmentIterator frag_iterator(fragment);
    41  for (int d = 0; d < InputIterator::Iterations::kD; ++d) {
    42  for (int h = 0; h < InputIterator::Iterations::kH; ++h) {
    43  for (int w = 0; w < InputIterator::Iterations::kW; ++w) {
    44  for (int c = 0; c < InputIterator::Iterations::kC; ++c) {
    45  if (iterator.valid(d, h, w, c)) {
    46  iterator.load_element(reinterpret_cast<typename InputIterator::AccessType &>(
    47  frag_iterator.at(d, h, w, c)),
    48  d,
    49  h,
    50  w,
    51  c);
    52  }
    53  }
    54  if (w < InputIterator::Iterations::kW - 1) {
    55  iterator.inc_w();
    56  }
    57  }
    58  if (h < InputIterator::Iterations::kH - 1) {
    59  iterator.inc_h();
    60  }
    61  }
    62  if (d < InputIterator::Iterations::kD - 1) {
    63  iterator.inc_d();
    64  }
    65  }
    66  iterator.inc_advance();
    67 }
    68 
    69 template <typename OutputIterator, typename Fragment>
    70 CUTLASS_HOST_DEVICE void iterator_store(OutputIterator &iterator, Fragment &fragment) {
    71  typename OutputIterator::FragmentIterator frag_iterator(fragment);
    72  for (int d = 0; d < OutputIterator::Iterations::kD; ++d) {
    73  for (int h = 0; h < OutputIterator::Iterations::kH; ++h) {
    74  for (int w = 0; w < OutputIterator::Iterations::kW; ++w) {
    75  for (int c = 0; c < OutputIterator::Iterations::kC; ++c) {
    76  if (iterator.valid(d, h, w, c)) {
    77  iterator.store_element(reinterpret_cast<typename OutputIterator::AccessType &>(
    78  frag_iterator.at(d, h, w, c)),
    79  d,
    80  h,
    81  w,
    82  c);
    83  }
    84  }
    85  if (w < OutputIterator::Iterations::kW - 1) {
    86  iterator.inc_w();
    87  }
    88  }
    89  if (h < OutputIterator::Iterations::kH - 1) {
    90  iterator.inc_h();
    91  }
    92  }
    93  if (d < OutputIterator::Iterations::kD - 1) {
    94  iterator.inc_d();
    95  }
    96  }
    97  iterator.inc_advance();
    98 }
    100 
    101 } // namespace cutlass
    Definition: convert.h:33
    A template defining Fragment Concept.
    Definition: fragment.h:99
    -
    Definition: load_store.h:131
    Defines container classes and iterators for managing a statically sized vector of boolean predicates...
    -
    static CUTLASS_DEVICE int get(int d, int h, int w, int c)
    Definition: shape.h:211
    -
    CUTLASS_HOST_DEVICE void iterator_load_post_increment(InputIterator &iterator, Fragment &fragment, typename InputIterator::Index offset, ConstPredicateAdapter predicate_adapter)
    Loads a fragment from an input iterator, masked by a predicate iterator.
    Definition: iterator_access.h:125
    Defines abstractions for efficiently loading and storing vectors to memory.
    #define CUTLASS_HOST_DEVICE
    Definition: cutlass.h:46
    -
    CUTLASS_DEVICE void shared_iterator_store(OutputIterator &iterator, Fragment const &fragment)
    Stores a fragment to a shared memory output iterator.
    Definition: iterator_access.h:228
    -
    Element_ Element
    The element.
    Definition: fragment.h:108
    -
    Always returns true predicate.
    Definition: predicate_vector.h:426
    -
    CUTLASS_HOST_DEVICE void iterator_store(OutputIterator &iterator, Fragment &fragment)
    Stores a fragment to an output iterator.
    Definition: iterator_access.h:193
    -
    Definition: fragment_load_store.h:91
    -
    CUTLASS_HOST_DEVICE void iterator_load(InputIterator &iterator, Fragment &fragment)
    Loads a fragment from an input iterator.
    Definition: iterator_access.h:41
    +
    CUTLASS_HOST_DEVICE void iterator_store(OutputIterator &iterator, Fragment &fragment)
    Definition: iterator_access.h:70
    +
    CUTLASS_HOST_DEVICE void iterator_load(InputIterator &iterator, Fragment &fragment)
    Definition: iterator_access.h:39
    Defines Shape implementing the Layout concept for representing a 4D hypercube of objects.
    diff --git a/docs/kernel__launch_8h.html b/docs/kernel__launch_8h.html new file mode 100644 index 000000000..192d541b3 --- /dev/null +++ b/docs/kernel__launch_8h.html @@ -0,0 +1,108 @@ + + + + + + + +Cutlass: kernel_launch.h File Reference + + + + + + + + + + +
    +
    + + + + + + +
    +
    Cutlass +
    +
    CUDA Templates for Linear Algebra Subroutines and Solvers
    +
    +
    + + + + + + + + +
    +
    + + +
    + +
    + + +
    +
    + +
    +
    kernel_launch.h File Reference
    +
    +
    + +

    Defines structures and helpers to launch CUDA kernels within CUTLASS. +More...

    +
    #include "cutlass/cutlass.h"
    +
    +

    Go to the source code of this file.

    + + + + + +

    +Classes

    struct  cutlass::KernelLaunchConfiguration
     Structure containing the basic launch configuration of a CUDA kernel. More...
     
    + + + +

    +Namespaces

     cutlass
     
    +
    + + + + diff --git a/docs/kernel__launch_8h_source.html b/docs/kernel__launch_8h_source.html new file mode 100644 index 000000000..52c7a5e07 --- /dev/null +++ b/docs/kernel__launch_8h_source.html @@ -0,0 +1,95 @@ + + + + + + + +Cutlass: kernel_launch.h Source File + + + + + + + + + + +
    +
    + + + + + + +
    +
    Cutlass +
    +
    CUDA Templates for Linear Algebra Subroutines and Solvers
    +
    +
    + + + + + + + + +
    +
    + + +
    + +
    + + +
    +
    +
    +
    kernel_launch.h
    +
    +
    +Go to the documentation of this file.
    1 /***************************************************************************************************
    2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
    3  *
    4  * Redistribution and use in source and binary forms, with or without modification, are permitted
    5  * provided that the following conditions are met:
    6  * * Redistributions of source code must retain the above copyright notice, this list of
    7  * conditions and the following disclaimer.
    8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
    9  * conditions and the following disclaimer in the documentation and/or other materials
    10  * provided with the distribution.
    11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
    12  * to endorse or promote products derived from this software without specific prior written
    13  * permission.
    14  *
    15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
    16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
    17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
    18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
    19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
    20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
    21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    23  *
    24  **************************************************************************************************/
    29 #pragma once
    30 
    31 #include "cutlass/cutlass.h"
    32 
    33 namespace cutlass {
    34 
    36 
    39 
    41  dim3 grid;
    42 
    44  dim3 block;
    45 
    47  size_t dynamic_smem;
    48 
    49  //
    50  // Methods
    51  //
    52 
    56  dim3 _grid = dim3(1,1,1),
    57  dim3 _block = dim3(1,1,1),
    58  size_t _dynamic_smem = 0
    59  ):
    60  grid(_grid),
    61  block(_block),
    62  dynamic_smem(_dynamic_smem) { }
    63 };
    64 
    66 
    67 } // namespace cutlass
    CUTLASS_HOST_DEVICE KernelLaunchConfiguration(dim3 _grid=dim3(1, 1, 1), dim3 _block=dim3(1, 1, 1), size_t _dynamic_smem=0)
    Constructs a KernellaunchConfiguration object.
    Definition: kernel_launch.h:55
    +
    Definition: convert.h:33
    +
    Structure containing the basic launch configuration of a CUDA kernel.
    Definition: kernel_launch.h:38
    +
    #define CUTLASS_HOST_DEVICE
    Definition: cutlass.h:46
    +
    size_t dynamic_smem
    Bytes of dynamically allocated SMEM in addition to static SMEM.
    Definition: kernel_launch.h:47
    +
    dim3 block
    CUDA threablock dimensions.
    Definition: kernel_launch.h:44
    +
    dim3 grid
    CUDA grid dimensions.
    Definition: kernel_launch.h:41
    +
    Basic include for CUTLASS macros.
    +
    + + + + diff --git a/docs/linear__scaling_8h.html b/docs/linear__scaling_8h.html index 060be3aa3..132c09d48 100644 --- a/docs/linear__scaling_8h.html +++ b/docs/linear__scaling_8h.html @@ -74,7 +74,8 @@ $(function() {
    linear_scaling.h File Reference
    @@ -82,7 +83,7 @@ $(function() {

    Implements the BLAS linear scaling function alpha*AB + beta*C. More...

    -
    #include <cutlass/fragment_multiply_add.h>
    +

    Go to the source code of this file.

    @@ -101,11 +102,19 @@ Namespaces +
     
     cutlass::gemm
     
    + + + + + +

    +Functions

    template<typename T >
    CUTLASS_DEVICE bool cutlass::gemm::is_zero (T x)
     
    CUTLASS_DEVICE bool cutlass::gemm::is_zero (half x)
     
    diff --git a/docs/linear__scaling_8h_source.html b/docs/linear__scaling_8h_source.html index d9817ed09..b00e58598 100644 --- a/docs/linear__scaling_8h_source.html +++ b/docs/linear__scaling_8h_source.html @@ -76,25 +76,33 @@ $(function() {
    linear_scaling.h
    -Go to the documentation of this file.
    1 
    2 /***************************************************************************************************
    3  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
    4  *
    5  * Redistribution and use in source and binary forms, with or without modification, are permitted
    6  * provided that the following conditions are met:
    7  * * Redistributions of source code must retain the above copyright notice, this list of
    8  * conditions and the following disclaimer.
    9  * * Redistributions in binary form must reproduce the above copyright notice, this list of
    10  * conditions and the following disclaimer in the documentation and/or other materials
    11  * provided with the distribution.
    12  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
    13  * to endorse or promote products derived from this software without specific prior written
    14  * permission.
    15  *
    16  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
    17  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
    18  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
    19  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
    20  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
    21  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
    22  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    23  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    24  *
    25  **************************************************************************************************/
    29 #pragma once
    30 
    32 
    33 namespace cutlass {
    34 namespace gemm {
    35 
    37 
    39 template <typename Scalar_, typename FragmentMultiplyAdd_ = FragmentMultiplyAdd<Scalar_> >
    40 struct LinearScaling {
    41  // The scalar.
    42  typedef Scalar_ Scalar;
    43  // The adapater.
    44  typedef FragmentMultiplyAdd_ FragmentMultiplyAdd;
    45 
    47  struct Params {
    50 
    52  template <typename GemmDesc_>
    53  CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
    54  alpha = desc.alpha;
    55  beta = desc.beta;
    56  return 0;
    57  }
    58  };
    59 
    61  CUTLASS_DEVICE LinearScaling(Params const& params) : alpha(params.alpha), beta(params.beta) {}
    62 
    64  template <typename Fragment_>
    65  CUTLASS_DEVICE void evaluate(Fragment_ const& accum, Fragment_& output) {
    67  mad.multiply(alpha, accum, output);
    68  }
    69 
    71  template <typename Fragment_>
    72  CUTLASS_DEVICE void evaluate(Fragment_ const& accum, Fragment_ const& old, Fragment_& output) {
    74  Fragment_ tmp;
    75  mad.multiply(beta, old, tmp);
    76  mad.multiply_add(alpha, accum, tmp, output);
    77  }
    78 
    81 };
    82 
    84 
    85 } // namespace gemm
    86 } // namespace cutlass
    Definition: convert.h:33
    -
    Scalar alpha
    The alpha/beta scaling params.
    Definition: linear_scaling.h:49
    -
    Scalar alpha
    The alpha/beta scaling factors.
    Definition: linear_scaling.h:80
    -
    CUTLASS_DEVICE LinearScaling(Params const &params)
    Ctor.
    Definition: linear_scaling.h:61
    -
    CUTLASS_DEVICE void evaluate(Fragment_ const &accum, Fragment_ const &old, Fragment_ &output)
    Evaluate the functor.
    Definition: linear_scaling.h:72
    -
    Scalar beta
    Definition: linear_scaling.h:49
    -
    CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const &desc)
    Initialize the parameters.
    Definition: linear_scaling.h:53
    -
    Scalar beta
    Definition: linear_scaling.h:80
    +Go to the documentation of this file.
    1 
    2 /***************************************************************************************************
    3  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
    4  *
    5  * Redistribution and use in source and binary forms, with or without modification, are permitted
    6  * provided that the following conditions are met:
    7  * * Redistributions of source code must retain the above copyright notice, this list of
    8  * conditions and the following disclaimer.
    9  * * Redistributions in binary form must reproduce the above copyright notice, this list of
    10  * conditions and the following disclaimer in the documentation and/or other materials
    11  * provided with the distribution.
    12  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
    13  * to endorse or promote products derived from this software without specific prior written
    14  * permission.
    15  *
    16  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
    17  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
    18  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
    19  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
    20  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
    21  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
    22  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    23  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    24  *
    25  **************************************************************************************************/
    29 #pragma once
    30 
    32 
    33 namespace cutlass {
    34 namespace gemm {
    35 
    37 
    38 template <typename T>
    39 CUTLASS_DEVICE bool is_zero(T x) {
    40  return x == T(0);
    41 }
    42 
    43 #if !defined(__CUDACC_RTC__) || defined(CUTLASS_NVRTC_HAS_FP16)
    44 CUTLASS_DEVICE bool is_zero(half x) { return reinterpret_cast<int16_t&>(x) == int16_t(0); }
    45 #endif
    46 
    48 
    50 template <typename Scalar_, typename FragmentMultiplyAdd_ = FragmentMultiplyAdd<Scalar_, Scalar_> >
    51 struct LinearScaling {
    52  // The scalar.
    53  typedef Scalar_ Scalar;
    54  // The accumulator Type
    55  typedef typename FragmentMultiplyAdd_::ScalarAccum ScalarAccum;
    56  // The adapater.
    57  typedef FragmentMultiplyAdd_ FragmentMultiplyAdd;
    58 
    60  struct Params {
    63 
    64  //
    65  // Methods
    66  //
    67 
    68  // Constructor
    70  Params(Scalar _alpha = 0, Scalar _beta = 0) : alpha(_alpha), beta(_beta) {}
    71 
    74  alpha = _alpha;
    75  beta = _beta;
    76  return 0;
    77  }
    78 
    80  template <typename GemmDesc_>
    81  CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
    82  alpha = desc.alpha;
    83  beta = desc.beta;
    84  return 0;
    85  }
    86  };
    87 
    88  //
    89  // Data members
    90  //
    91 
    93 
    94  //
    95  // Methods
    96  //
    97 
    99  CUTLASS_DEVICE LinearScaling() { }
    100 
    102  CUTLASS_DEVICE LinearScaling(Params const& _params) : params(_params) {}
    103 
    107  CUTLASS_DEVICE
    108  bool source_required() const {
    109  return !is_zero(params.beta);
    110  }
    111 
    113  template <typename FragmentA_, typename FragmentB_>
    114  CUTLASS_DEVICE void evaluate(FragmentA_ const& accum, FragmentB_& output) {
    116  mad.multiply(params.alpha, accum, output);
    117 
    118  }
    119 
    121  template <typename ScalarAccum, typename ScalarOutput, int size>
    122  CUTLASS_DEVICE void evaluate(ScalarAccum const *accum, ScalarOutput *output) {
    123  Fragment<ScalarAccum, size> FragAccum;
    124  Fragment<ScalarOutput, size> FragOutput;
    125 #pragma unroll
    126  for (int i = 0; i < size; i++) {
    127  FragAccum[i] = accum[i];
    128  FragOutput[i] = output[i];
    129  }
    130  evaluate(FragAccum, FragOutput);
    131 #pragma unroll
    132  for (int i = 0; i < size; i++) {
    133  output[i] = FragOutput[i];
    134  }
    135  }
    136 
    138  template <typename FragmentA_, typename FragmentB_>
    139  CUTLASS_DEVICE void evaluate(FragmentA_ const& accum, FragmentB_ const& old, FragmentB_& output) {
    141  FragmentB_ tmp;
    142  mad.multiply(params.beta, old, tmp);
    143  mad.multiply_add(params.alpha, accum, tmp, output);
    144  }
    145 
    147  template <typename ScalarAccum, typename ScalarOutput, int size>
    148  CUTLASS_DEVICE void evaluate(ScalarAccum const *accum, ScalarOutput const *old, ScalarOutput *output) {
    149  Fragment<ScalarAccum, size> FragAccum;
    150  Fragment<ScalarOutput, size> FragOutput;
    152 #pragma unroll
    153  for (int i = 0; i < size; i++) {
    154  FragAccum[i] = accum[i];
    155  FragOutput[i] = output[i];
    156  FragOld[i] = old[i];
    157  }
    158  evaluate(FragAccum, FragOld, FragOutput);
    159 #pragma unroll
    160  for (int i = 0; i < size; i++) {
    161  output[i] = FragOutput[i];
    162  }
    163  }
    164 };
    165 
    167 
    168 } // namespace gemm
    169 } // namespace cutlass
    CUTLASS_HOST_DEVICE int initialize(Scalar _alpha, Scalar _beta)
    Initialize the parameters.
    Definition: linear_scaling.h:73
    +
    Definition: convert.h:33
    +
    Scalar alpha
    The alpha/beta scaling params.
    Definition: linear_scaling.h:62
    +
    CUTLASS_DEVICE bool source_required() const
    Definition: linear_scaling.h:108
    +
    CUTLASS_DEVICE void evaluate(ScalarAccum const *accum, ScalarOutput *output)
    Evaluate the functor, without using fragment in the API.
    Definition: linear_scaling.h:122
    +
    CUTLASS_DEVICE void evaluate(FragmentA_ const &accum, FragmentB_ const &old, FragmentB_ &output)
    Evaluate the functor.
    Definition: linear_scaling.h:139
    +
    CUTLASS_DEVICE void evaluate(FragmentA_ const &accum, FragmentB_ &output)
    Evaluate the functor.
    Definition: linear_scaling.h:114
    +
    Scalar beta
    Definition: linear_scaling.h:62
    +
    A template defining Fragment Concept.
    Definition: fragment.h:99
    +
    Params params
    Definition: linear_scaling.h:92
    +
    FragmentMultiplyAdd_::ScalarAccum ScalarAccum
    Definition: linear_scaling.h:55
    +
    CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const &desc)
    Initialize the parameters.
    Definition: linear_scaling.h:81
    Defines multiply-add operations on fragments within a thread.
    -
    FragmentMultiplyAdd_ FragmentMultiplyAdd
    Definition: linear_scaling.h:44
    +
    FragmentMultiplyAdd_ FragmentMultiplyAdd
    Definition: linear_scaling.h:57
    +
    CUTLASS_DEVICE LinearScaling()
    Ctor.
    Definition: linear_scaling.h:99
    +
    CUTLASS_DEVICE bool is_zero(T x)
    Definition: linear_scaling.h:39
    #define CUTLASS_HOST_DEVICE
    Definition: cutlass.h:46
    -
    CUTLASS_DEVICE void evaluate(Fragment_ const &accum, Fragment_ &output)
    Evaluate the functor.
    Definition: linear_scaling.h:65
    -
    The parameters.
    Definition: linear_scaling.h:47
    -
    Functor to compute linear combination of fragments.
    Definition: linear_scaling.h:40
    -
    Scalar_ Scalar
    Definition: linear_scaling.h:42
    +
    CUTLASS_DEVICE LinearScaling(Params const &_params)
    Ctor.
    Definition: linear_scaling.h:102
    +
    The parameters.
    Definition: linear_scaling.h:60
    +
    Functor to compute linear combination of fragments.
    Definition: linear_scaling.h:51
    +
    Scalar_ Scalar
    Definition: linear_scaling.h:53
    +
    CUTLASS_DEVICE void evaluate(ScalarAccum const *accum, ScalarOutput const *old, ScalarOutput *output)
    Evaluate the functor, without using fragment in the API.
    Definition: linear_scaling.h:148
    +
    CUTLASS_HOST_DEVICE Params(Scalar _alpha=0, Scalar _beta=0)
    Definition: linear_scaling.h:70
    diff --git a/docs/linear__scaling__device__ptr_8h.html b/docs/linear__scaling__device__ptr_8h.html new file mode 100644 index 000000000..ad2add1db --- /dev/null +++ b/docs/linear__scaling__device__ptr_8h.html @@ -0,0 +1,114 @@ + + + + + + + +Cutlass: linear_scaling_device_ptr.h File Reference + + + + + + + + + + +
    +
    + + + + + + +
    +
    Cutlass +
    +
    CUDA Templates for Linear Algebra Subroutines and Solvers
    +
    +
    + + + + + + + + +
    +
    + + +
    + +
    + + +
    +
    + +
    +
    linear_scaling_device_ptr.h File Reference
    +
    +
    + +

    Implements the BLAS linear scaling function alpha*AB + beta*C. +More...

    + +

    Go to the source code of this file.

    + + + + + + + +

    +Classes

    struct  cutlass::gemm::LinearScalingDevicePtr< Scalar_, FragmentMultiplyAdd_ >
     
    class  cutlass::gemm::LinearScalingDevicePtr< Scalar_, FragmentMultiplyAdd_ >::Params
     The parameters. More...
     
    + + + + + +

    +Namespaces

     cutlass
     
     cutlass::gemm
     
    +
    + + + + diff --git a/docs/linear__scaling__device__ptr_8h_source.html b/docs/linear__scaling__device__ptr_8h_source.html new file mode 100644 index 000000000..2fae588f7 --- /dev/null +++ b/docs/linear__scaling__device__ptr_8h_source.html @@ -0,0 +1,109 @@ + + + + + + + +Cutlass: linear_scaling_device_ptr.h Source File + + + + + + + + + + +
    +
    + + + + + + +
    +
    Cutlass +
    +
    CUDA Templates for Linear Algebra Subroutines and Solvers
    +
    +
    + + + + + + + + +
    +
    + + +
    + +
    + + +
    +
    +
    +
    linear_scaling_device_ptr.h
    +
    +
    +Go to the documentation of this file.
    1 /***************************************************************************************************
    2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
    3  *
    4  * Redistribution and use in source and binary forms, with or without modification, are permitted
    5  * provided that the following conditions are met:
    6  * * Redistributions of source code must retain the above copyright notice, this list of
    7  * conditions and the following disclaimer.
    8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
    9  * conditions and the following disclaimer in the documentation and/or other materials
    10  * provided with the distribution.
    11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
    12  * to endorse or promote products derived from this software without specific prior written
    13  * permission.
    14  *
    15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
    16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
    17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
    18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
    19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
    20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
    21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    23  *
    24  **************************************************************************************************/
    28 #pragma once
    29 
    30 #include "cutlass/cutlass.h"
    33 
    34 namespace cutlass {
    35 
    37 
    38 namespace gemm {
    39 
    41 
    45 template <typename Scalar_, typename FragmentMultiplyAdd_ = FragmentMultiplyAdd<Scalar_, Scalar_> >
    46 struct LinearScalingDevicePtr : public LinearScaling<Scalar_, FragmentMultiplyAdd_> {
    47 
    50 
    51  // The scalar.
    52  typedef typename Base::Scalar Scalar;
    53 
    55  class Params {
    56  private:
    59 
    62 
    63  public:
    64  //
    65  // Methods
    66  //
    67 
    68  // Constructor
    70  Params() {}
    71 
    72  // Constructor
    75  Scalar alpha,
    76  Scalar beta
    77  ):
    78  alpha_(alpha),
    79  beta_(beta) {}
    80 
    81  // Constructor
    84  Scalar const *alpha_ptr,
    85  Scalar const *beta_ptr
    86  ):
    87  alpha_(alpha_ptr),
    88  beta_(alpha_ptr) {}
    89 
    92  Scalar alpha,
    93  Scalar beta) {
    94 
    95  alpha_ = alpha;
    96  beta_ = beta;
    97 
    98  return 0;
    99  }
    100 
    103  Scalar const *alpha,
    104  Scalar const *beta) {
    105 
    106  alpha_ = alpha;
    107  beta_= beta;
    108 
    109  return 0;
    110  }
    111 
    113  template <typename GemmDesc_>
    114  CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc) {
    115 
    116  alpha_ = desc.alpha;
    117  beta_ = desc.beta;
    118 
    119  return 0;
    120  }
    121 
    124  Scalar alpha() const {
    125  return alpha_;
    126  }
    127 
    130  Scalar beta() const {
    131  return beta_;
    132  }
    133  };
    134 
    135  //
    136  // Methods
    137  //
    138 
    141  this->params.alpha = _params.alpha();
    142  this->params.beta = _params.beta();
    143  }
    144 };
    145 
    147 
    148 } // namespace gemm
    149 } // namespace cutlass
    CUTLASS_HOST_DEVICE int initialize(Scalar const *alpha, Scalar const *beta)
    Initialize the parameters.
    Definition: linear_scaling_device_ptr.h:102
    +
    The parameters.
    Definition: linear_scaling_device_ptr.h:55
    +
    Definition: convert.h:33
    +
    CUTLASS_HOST_DEVICE Params(Scalar const *alpha_ptr, Scalar const *beta_ptr)
    Definition: linear_scaling_device_ptr.h:83
    +
    Implements the BLAS linear scaling function alpha*AB + beta*C.
    +
    Implements the BLAS linear scaling function alpha*AB + beta*C.
    +
    CUTLASS_HOST_DEVICE int initialize(Scalar alpha, Scalar beta)
    Initialize the parameters.
    Definition: linear_scaling_device_ptr.h:91
    +
    Params params
    Definition: linear_scaling.h:92
    +
    LinearScaling< Scalar_, FragmentMultiplyAdd_ > Base
    Linear Scaling class used.
    Definition: linear_scaling_device_ptr.h:49
    +
    CUTLASS_HOST_DEVICE Params()
    Definition: linear_scaling_device_ptr.h:70
    +
    #define CUTLASS_HOST_DEVICE
    Definition: cutlass.h:46
    +
    CUTLASS_HOST_DEVICE Params(Scalar alpha, Scalar beta)
    Definition: linear_scaling_device_ptr.h:74
    +
    CUTLASS_HOST_DEVICE Scalar beta() const
    Gets the beta scalar.
    Definition: linear_scaling_device_ptr.h:130
    +
    CUTLASS_HOST_DEVICE LinearScalingDevicePtr(Params const &_params)
    Ctor.
    Definition: linear_scaling_device_ptr.h:140
    +
    CUTLASS_HOST_DEVICE Scalar alpha() const
    Gets the alpha scalar.
    Definition: linear_scaling_device_ptr.h:124
    +
    Definition: linear_scaling_device_ptr.h:46
    +
    Functor to compute linear combination of fragments.
    Definition: linear_scaling.h:51
    +
    Scalar_ Scalar
    Definition: linear_scaling.h:53
    +
    Base::Scalar Scalar
    Definition: linear_scaling_device_ptr.h:52
    +
    Basic include for CUTLASS macros.
    +
    CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const &desc)
    Initialize the parameters.
    Definition: linear_scaling_device_ptr.h:114
    + +
    + + + + diff --git a/docs/load__store_8h.html b/docs/load__store_8h.html index b23ec3cbf..30a4e7334 100644 --- a/docs/load__store_8h.html +++ b/docs/load__store_8h.html @@ -82,7 +82,7 @@ $(function() {

    Defines abstractions for efficiently loading and storing vectors to memory. More...

    -
    #include <cutlass/vector.h>
    +
    #include "cutlass/vector.h"

    Go to the source code of this file.

    @@ -91,25 +91,43 @@ Classes - + + - + - + + - + - + - + - + - + - + - + + + + + + + + + + + + + + + + +
    struct  cutlass::MemorySpace
     Enum to specify which memory space data resides in. More...
     
    struct  cutlass::Load< Scalar_, Lanes_, Memory_, bool, size_t >
    struct  cutlass::FragmentElementType
     Specifies whether iterator storage fragment consists of Scalar values or WMMA matrix. More...
     
    struct  cutlass::Load< Scalar_, Lanes_, Memory_, true, 4 >
    struct  cutlass::Load< Scalar_, kAccessSize, Memory_, kFragmentElementType, FragmentElement_, kStride, size >
     
    struct  cutlass::Load< Scalar_, Lanes_, Memory_, true, 8 >
    struct  cutlass::Load< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, 1, 2 >
     Partial specialization for 16b loads. More...
     
    struct  cutlass::Load< double, 2, Memory_, true, 16 >
    struct  cutlass::Load< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 4 >
     
    struct  cutlass::Load< Scalar_, Lanes_, Memory_, true, 16 >
    struct  cutlass::Load< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 8 >
     
    struct  cutlass::Store< Scalar_, Lanes_, Memory_, bool, size_t >
    struct  cutlass::Load< double, 2, Memory_, FragmentElementType::kScalar, double, kStride, 16 >
     
    struct  cutlass::Store< Scalar_, Lanes_, Memory_, true, 4 >
    struct  cutlass::Load< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 16 >
     
    struct  cutlass::Store< Scalar_, Lanes_, Memory_, true, 8 >
    struct  cutlass::Store< Scalar_, kAccessSize, Memory_, kFragmentElementType, FragmentElement_, kStride, size >
     
    struct  cutlass::Store< double, 2, Memory_, true, 16 >
    struct  cutlass::Store< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, 1, 2 >
     
    struct  cutlass::Store< Scalar_, Lanes_, Memory_, true, 16 >
    struct  cutlass::Store< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 4 >
     
    struct  cutlass::Store< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 8 >
     
    struct  cutlass::Store< double, 2, Memory_, FragmentElementType::kScalar, double, kStride, 16 >
     
    struct  cutlass::Store< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 16 >
     
    struct  cutlass::Load< Scalar_, kAccessSize, Memory_, FragmentElementType::kWmmaMatrix, FragmentElement_, kStride, size >
     
    struct  cutlass::Load< Vector< bin1_t, 32 >, kAccessSize, Memory_, FragmentElementType::kWmmaMatrix, FragmentElement_, kStride, size >
     
    struct  cutlass::Load< Vector< int4_t, 8 >, kAccessSize, Memory_, FragmentElementType::kWmmaMatrix, FragmentElement_, kStride, size >
     
    struct  cutlass::Load< Vector< uint4_t, 8 >, kAccessSize, Memory_, FragmentElementType::kWmmaMatrix, FragmentElement_, kStride, size >
     
    struct  cutlass::Store< Scalar_, kAccessSize, Memory_, FragmentElementType::kWmmaMatrix, FragmentElement_, kStride, size >
     
    diff --git a/docs/load__store_8h_source.html b/docs/load__store_8h_source.html index e421cbf27..9fc9c8668 100644 --- a/docs/load__store_8h_source.html +++ b/docs/load__store_8h_source.html @@ -76,41 +76,64 @@ $(function() {
    load_store.h
    -Go to the documentation of this file.
    1 /***************************************************************************************************
    2  * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
    3  *
    4  * Redistribution and use in source and binary forms, with or without modification, are permitted
    5  * provided that the following conditions are met:
    6  * * Redistributions of source code must retain the above copyright notice, this list of
    7  * conditions and the following disclaimer.
    8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
    9  * conditions and the following disclaimer in the documentation and/or other materials
    10  * provided with the distribution.
    11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
    12  * to endorse or promote products derived from this software without specific prior written
    13  * permission.
    14  *
    15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
    16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
    17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
    18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
    19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
    20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
    21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    23  *
    24  **************************************************************************************************/
    28 #pragma once
    29 
    30 #include <cutlass/vector.h>
    31 
    32 namespace cutlass {
    33 
    35 
    39 struct MemorySpace {
    40  enum Kind {
    41  kGeneric, // Data accessed through pointer dereferencing
    42  kShared, // Data resides in shared memory
    43  kGlobal // Data resides in global memory
    44  };
    45 };
    46 
    48 
    49 template <typename Scalar_,
    50  int Lanes_,
    51  MemorySpace::Kind Memory_,
    52  bool = (Lanes_ > 1),
    53  size_t = (sizeof(Scalar_) * Lanes_)>
    54 struct Load {
    57 
    59  static CUTLASS_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
    60  dst = reinterpret_cast<AccessType const*>(&pointer[offset])[0];
    61  }
    62 };
    63 
    65 
    66 template <typename Scalar_, int Lanes_, MemorySpace::Kind Memory_>
    67 struct Load<Scalar_, Lanes_, Memory_, true, 4> {
    70 
    72  static CUTLASS_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
    73  dst.registers[0] = reinterpret_cast<uint32_t const*>(&pointer[offset])[0];
    74  }
    75 };
    76 
    78 
    79 template <typename Scalar_, int Lanes_, MemorySpace::Kind Memory_>
    80 struct Load<Scalar_, Lanes_, Memory_, true, 8> {
    83 
    85  static CUTLASS_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
    86  uint2 tmp = reinterpret_cast<uint2 const*>(&pointer[offset])[0];
    87  dst.registers[0] = tmp.x;
    88  dst.registers[1] = tmp.y;
    89  }
    90 };
    91 
    93 
    94 template <MemorySpace::Kind Memory_>
    95 struct Load<double, 2, Memory_, true, 16> {
    98 
    100  static CUTLASS_DEVICE void load(AccessType& dst, double const* pointer, int offset) {
    101  double2 tmp = reinterpret_cast<double2 const*>(&pointer[offset])[0];
    102  dst[0] = tmp.x;
    103  dst[1] = tmp.y;
    104  }
    105 };
    106 
    108 
    109 template <typename Scalar_, int Lanes_, MemorySpace::Kind Memory_>
    110 struct Load<Scalar_, Lanes_, Memory_, true, 16> {
    113 
    115  static CUTLASS_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
    116  uint4 tmp = reinterpret_cast<uint4 const*>(&pointer[offset])[0];
    117  dst.registers[0] = tmp.x;
    118  dst.registers[1] = tmp.y;
    119  dst.registers[2] = tmp.z;
    120  dst.registers[3] = tmp.w;
    121  }
    122 };
    123 
    125 
    126 template <typename Scalar_,
    127  int Lanes_,
    128  MemorySpace::Kind Memory_,
    129  bool = (Lanes_ > 1),
    130  size_t = (sizeof(Scalar_) * Lanes_)>
    131 struct Store {
    134 
    136  static CUTLASS_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
    137  pointer[offset] = src;
    138  }
    139 };
    140 
    142 
    143 template <typename Scalar_, int Lanes_, MemorySpace::Kind Memory_>
    144 struct Store<Scalar_, Lanes_, Memory_, true, 4> {
    147 
    149  static CUTLASS_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
    150  uint32_t* addr = reinterpret_cast<uint32_t*>(&pointer[offset]);
    151  addr[0] = src.registers[0];
    152  }
    153 };
    154 
    156 
    157 template <typename Scalar_, int Lanes_, MemorySpace::Kind Memory_>
    158 struct Store<Scalar_, Lanes_, Memory_, true, 8> {
    161 
    163  static CUTLASS_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
    164  uint2* addr = reinterpret_cast<uint2*>(&pointer[offset]);
    165  addr[0] = make_uint2(src.registers[0], src.registers[1]);
    166  }
    167 };
    168 
    170 
    171 template <MemorySpace::Kind Memory_>
    172 struct Store<double, 2, Memory_, true, 16> {
    175 
    177  static CUTLASS_DEVICE void store(AccessType const& src, double* pointer, int offset) {
    178  double2* addr = reinterpret_cast<double2*>(&pointer[offset]);
    179  addr[0] = make_double2(src[0], src[1]);
    180  }
    181 };
    182 
    184 
    185 template <typename Scalar_, int Lanes_, MemorySpace::Kind Memory_>
    186 struct Store<Scalar_, Lanes_, Memory_, true, 16> {
    189 
    191  static CUTLASS_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
    192  uint4* addr = reinterpret_cast<uint4*>(&pointer[offset]);
    193  addr[0] = make_uint4(src.registers[0], src.registers[1], src.registers[2], src.registers[3]);
    194  }
    195 };
    196 
    198 
    199 } // namespace cutlass
    Vectorize< Scalar_, Lanes_ >::Type AccessType
    The output type.
    Definition: load_store.h:188
    -
    Definition: load_store.h:42
    +Go to the documentation of this file.
    1 /***************************************************************************************************
    2  * Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
    3  *
    4  * Redistribution and use in source and binary forms, with or without modification, are permitted
    5  * provided that the following conditions are met:
    6  * * Redistributions of source code must retain the above copyright notice, this list of
    7  * conditions and the following disclaimer.
    8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
    9  * conditions and the following disclaimer in the documentation and/or other materials
    10  * provided with the distribution.
    11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
    12  * to endorse or promote products derived from this software without specific prior written
    13  * permission.
    14  *
    15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
    16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
    17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
    18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
    19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
    20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
    21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    23  *
    24  **************************************************************************************************/
    28 #pragma once
    29 
    30 #include "cutlass/vector.h"
    31 namespace cutlass {
    32 
    34 
    38 struct MemorySpace {
    39  enum Kind {
    40  kGeneric, // Data accessed through pointer dereferencing
    41  kShared, // Data resides in shared memory
    42  kGlobal // Data resides in global memory
    43  };
    44 };
    45 
    49 };
    50 
    52 
    53 template <typename Scalar_,
    54  int kAccessSize,
    55  MemorySpace::Kind Memory_,
    57  typename FragmentElement_ = Scalar_,
    58  int kStride = 1,
    59  size_t size = (sizeof(Scalar_) * kAccessSize)>
    60 struct Load {
    63 
    65  static CUTLASS_HOST_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
    66  dst = *reinterpret_cast<AccessType const*>(pointer + offset);
    67  }
    68 
    69 };
    70 
    72 
    74 template <typename Scalar_, int kAccessSize, MemorySpace::Kind Memory_>
    75 struct Load<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, 1, 2> {
    78 
    80  static CUTLASS_HOST_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
    81  reinterpret_cast<uint16_t&>(dst) = reinterpret_cast<uint16_t const*>(&pointer[offset])[0];
    82  }
    83 };
    84 
    86 
    87 template <typename Scalar_, int kAccessSize, MemorySpace::Kind Memory_, int kStride>
    88 struct Load<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 4> {
    91 
    93  static CUTLASS_HOST_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
    94  dst.registers[0] = reinterpret_cast<uint32_t const*>(&pointer[offset])[0];
    95  }
    96 
    97 };
    98 
    100 
    101 template <typename Scalar_, int kAccessSize, MemorySpace::Kind Memory_, int kStride>
    102 struct Load<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 8> {
    105 
    107  static CUTLASS_HOST_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
    108  uint2 tmp = reinterpret_cast<uint2 const*>(&pointer[offset])[0];
    109  dst.registers[0] = tmp.x;
    110  dst.registers[1] = tmp.y;
    111  }
    112 };
    113 
    115 
    116 template <MemorySpace::Kind Memory_, int kStride>
    117 struct Load<double, 2, Memory_, FragmentElementType::kScalar, double, kStride, 16> {
    120 
    122  static CUTLASS_HOST_DEVICE void load(AccessType& dst, double const* pointer, int offset) {
    123  double2 tmp = reinterpret_cast<double2 const*>(&pointer[offset])[0];
    124  dst[0] = tmp.x;
    125  dst[1] = tmp.y;
    126  }
    127 };
    128 
    130 
    131 #if defined(__CUDACC_VERSION_MAJOR) && __CUDACC_VERSION_MAJOR < 10
    132 // WAR bug in NVCC where the upper and lower half of the register end up being the same
    133 template <MemorySpace::Kind Memory_, int kStride>
    134 struct Load<half, 8, Memory_, FragmentElementType::kScalar, half, kStride, 16> {
    136  typedef typename Vectorize<half, 8>::Type AccessType;
    137 
    139  static CUTLASS_HOST_DEVICE void load(AccessType& dst, half const* pointer, int offset) {
    140  int2 tmp = reinterpret_cast<int2 const*>(&pointer[offset])[0];
    141  dst.registers[0] = tmp.x;
    142  dst.registers[1] = tmp.y;
    143 
    144  tmp = reinterpret_cast<int2 const*>(&pointer[offset + 4])[0];
    145  dst.registers[2] = tmp.x;
    146  dst.registers[3] = tmp.y;
    147  }
    148 };
    149 
    150 #endif
    151 
    153 
    154 template <typename Scalar_, int kAccessSize, MemorySpace::Kind Memory_, int kStride>
    155 struct Load<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 16> {
    158 
    160  static CUTLASS_HOST_DEVICE void load(AccessType& dst, Scalar_ const* pointer, int offset) {
    161  uint4 tmp = reinterpret_cast<uint4 const*>(&pointer[offset])[0];
    162  dst.registers[0] = tmp.x;
    163  dst.registers[1] = tmp.y;
    164  dst.registers[2] = tmp.z;
    165  dst.registers[3] = tmp.w;
    166  }
    167 };
    168 
    170 
    171 template <typename Scalar_,
    172  int kAccessSize,
    173  MemorySpace::Kind Memory_,
    175  typename FragmentElement_ = Scalar_,
    176  int kStride = 1,
    177  size_t size = (sizeof(Scalar_) * kAccessSize)>
    178 struct Store {
    181 
    183  static CUTLASS_HOST_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
    184  pointer[offset] = *reinterpret_cast<Scalar_ const*>(&src);
    185  }
    186 };
    187 
    189 
    190 template <typename Scalar_, int kAccessSize, MemorySpace::Kind Memory_>
    191 struct Store<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, 1, 2> {
    194 
    196  static CUTLASS_HOST_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
    197  uint16_t* addr = reinterpret_cast<uint16_t*>(&pointer[offset]);
    198  addr[0] = reinterpret_cast<uint16_t const&>(src);
    199  }
    200 };
    201 
    203 
    204 template <typename Scalar_, int kAccessSize, MemorySpace::Kind Memory_, int kStride>
    205 struct Store<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 4> {
    208 
    210  static CUTLASS_HOST_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
    211  uint32_t* addr = reinterpret_cast<uint32_t*>(&pointer[offset]);
    212  addr[0] = src.registers[0];
    213  }
    214 };
    215 
    217 
    218 template <typename Scalar_, int kAccessSize, MemorySpace::Kind Memory_, int kStride>
    219 struct Store<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 8> {
    222 
    224  static CUTLASS_HOST_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
    225  uint2* addr = reinterpret_cast<uint2*>(&pointer[offset]);
    226  addr[0] = make_uint2(src.registers[0], src.registers[1]);
    227  }
    228 };
    229 
    231 
    232 template <MemorySpace::Kind Memory_, int kStride>
    233 struct Store<double, 2, Memory_, FragmentElementType::kScalar, double, kStride, 16> {
    236 
    238  static CUTLASS_HOST_DEVICE void store(AccessType const& src, double* pointer, int offset) {
    239  double2* addr = reinterpret_cast<double2*>(&pointer[offset]);
    240  addr[0] = make_double2(src[0], src[1]);
    241  }
    242 };
    243 
    245 
    246 template <typename Scalar_, int kAccessSize, MemorySpace::Kind Memory_, int kStride>
    247 struct Store<Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 16> {
    250 
    252  static CUTLASS_HOST_DEVICE void store(AccessType const& src, Scalar_* pointer, int offset) {
    253  uint4* addr = reinterpret_cast<uint4*>(&pointer[offset]);
    254  addr[0] = make_uint4(src.registers[0], src.registers[1], src.registers[2], src.registers[3]);
    255  }
    256 };
    257 
    259 
    260 template <typename Scalar_,
    261  int kAccessSize,
    262  MemorySpace::Kind Memory_,
    263  typename FragmentElement_,
    264  int kStride,
    265  size_t size>
    266 struct Load<Scalar_,
    267  kAccessSize,
    268  Memory_,
    269  FragmentElementType::kWmmaMatrix,
    270  FragmentElement_,
    271  kStride,
    272  size> {
    274  typedef FragmentElement_ AccessType;
    275 
    277  static CUTLASS_HOST_DEVICE void load(AccessType& value, Scalar_ const* pointer, int offset) {
    278  value.load(&pointer[offset], kStride);
    279  }
    280 };
    281 
    283 
    284 template <int kAccessSize,
    285  MemorySpace::Kind Memory_,
    286  typename FragmentElement_,
    287  int kStride,
    288  size_t size>
    289 struct Load<Vector<bin1_t, 32>,
    290  kAccessSize,
    291  Memory_,
    292  FragmentElementType::kWmmaMatrix,
    293  FragmentElement_,
    294  kStride,
    295  size> {
    297  typedef FragmentElement_ AccessType;
    298 
    300  static CUTLASS_HOST_DEVICE void load(AccessType& value, Vector<bin1_t, 32> const* pointer,
    301  int offset) {
    302  value.load(&pointer[offset], kStride * 32);
    303  }
    304 };
    305 
    307 
    308 template <int kAccessSize,
    309  MemorySpace::Kind Memory_,
    310  typename FragmentElement_,
    311  int kStride,
    312  size_t size>
    313 struct Load<Vector<int4_t, 8>,
    314  kAccessSize,
    315  Memory_,
    316  FragmentElementType::kWmmaMatrix,
    317  FragmentElement_,
    318  kStride,
    319  size> {
    321  typedef FragmentElement_ AccessType;
    322 
    324  static CUTLASS_HOST_DEVICE void load(AccessType& value, Vector<int4_t, 8> const* pointer,
    325  int offset) {
    326  value.load(&pointer[offset], kStride * 8);
    327  }
    328 };
    329 
    331 
    332 template <int kAccessSize,
    333  MemorySpace::Kind Memory_,
    334  typename FragmentElement_,
    335  int kStride,
    336  size_t size>
    337 struct Load<Vector<uint4_t, 8>,
    338  kAccessSize,
    339  Memory_,
    340  FragmentElementType::kWmmaMatrix,
    341  FragmentElement_,
    342  kStride,
    343  size> {
    345  typedef FragmentElement_ AccessType;
    346 
    348  static CUTLASS_HOST_DEVICE void load(AccessType& value, Vector<uint4_t, 8> const* pointer,
    349  int offset) {
    350  value.load(&pointer[offset], kStride * 8);
    351  }
    352 };
    353 
    355 template <typename Scalar_,
    356  int kAccessSize,
    357  MemorySpace::Kind Memory_,
    358  typename FragmentElement_,
    359  int kStride,
    360  size_t size>
    361 struct Store<Scalar_,
    362  kAccessSize,
    363  Memory_,
    364  FragmentElementType::kWmmaMatrix,
    365  FragmentElement_,
    366  kStride,
    367  size> {
    369  typedef FragmentElement_ AccessType;
    370 
    372  static CUTLASS_HOST_DEVICE void store(AccessType const& value, Scalar_* pointer, int offset) {
    373  value.store(&pointer[offset], kStride);
    374  }
    375 };
    376 
    378 
    379 } // namespace cutlass
    static CUTLASS_HOST_DEVICE void load(AccessType &value, Vector< bin1_t, 32 > const *pointer, int offset)
    The load function.
    Definition: load_store.h:300
    +
    Vectorize< Scalar_, kAccessSize >::Type AccessType
    The output type.
    Definition: load_store.h:157
    +
    Vectorize< Scalar_, kAccessSize >::Type AccessType
    The output type.
    Definition: load_store.h:77
    +
    static CUTLASS_HOST_DEVICE void store(AccessType const &src, double *pointer, int offset)
    The store function.
    Definition: load_store.h:238
    +
    static CUTLASS_HOST_DEVICE void load(AccessType &value, Vector< int4_t, 8 > const *pointer, int offset)
    The load function.
    Definition: load_store.h:324
    +
    Definition: load_store.h:41
    Definition: convert.h:33
    -
    static CUTLASS_DEVICE void store(AccessType const &src, Scalar_ *pointer, int offset)
    The store function.
    Definition: load_store.h:163
    -
    Enum to specify which memory space data resides in.
    Definition: load_store.h:39
    -
    Definition: load_store.h:43
    -
    static CUTLASS_DEVICE void load(AccessType &dst, Scalar_ const *pointer, int offset)
    The load function.
    Definition: load_store.h:59
    -
    Vectorize< Scalar_, Lanes_ >::Type AccessType
    The output type.
    Definition: load_store.h:112
    -
    Vectorize< Scalar_, Lanes_ >::Type AccessType
    The output type.
    Definition: load_store.h:146
    -
    Kind
    Definition: load_store.h:40
    -
    Definition: load_store.h:131
    -
    static CUTLASS_DEVICE void store(AccessType const &src, Scalar_ *pointer, int offset)
    The store function.
    Definition: load_store.h:136
    -
    uint32_t registers[kRegisters]
    The data in registers.
    Definition: vector.h:80
    -
    Vectorize< double, 2 >::Type AccessType
    The output type.
    Definition: load_store.h:174
    -
    Definition: load_store.h:41
    -
    static CUTLASS_DEVICE void load(AccessType &dst, Scalar_ const *pointer, int offset)
    The store function.
    Definition: load_store.h:72
    -
    Vectorize< Scalar_, Lanes_ >::Type AccessType
    The output type.
    Definition: load_store.h:133
    -
    Definition: vector.h:61
    -
    static CUTLASS_DEVICE void load(AccessType &dst, Scalar_ const *pointer, int offset)
    The store function.
    Definition: load_store.h:85
    -
    Definition: load_store.h:54
    -
    Vectorize< Scalar_, Lanes_ >::Type AccessType
    The output type.
    Definition: load_store.h:82
    + +
    Definition: numeric_types.h:39
    +
    Enum to specify which memory space data resides in.
    Definition: load_store.h:38
    +
    static CUTLASS_HOST_DEVICE void store(AccessType const &src, Scalar_ *pointer, int offset)
    The store function.
    Definition: load_store.h:196
    +
    static CUTLASS_HOST_DEVICE void store(AccessType const &src, Scalar_ *pointer, int offset)
    The store function.
    Definition: load_store.h:252
    + +
    Specifies whether iterator storage fragment consists of Scalar values or WMMA matrix.
    Definition: load_store.h:47
    +
    Definition: load_store.h:42
    + + +
    Vectorize< double, 2 >::Type AccessType
    The output type.
    Definition: load_store.h:119
    +
    Vectorize< FragmentElement_, kAccessSize >::Type AccessType
    The output type.
    Definition: load_store.h:180
    +
    Kind
    Definition: load_store.h:39
    +
    Definition: load_store.h:178
    +
    static CUTLASS_HOST_DEVICE void load(AccessType &dst, Scalar_ const *pointer, int offset)
    The load function.
    Definition: load_store.h:160
    +
    uint32_t registers[kRegisters]
    The data in registers.
    Definition: vector.h:81
    +
    static CUTLASS_HOST_DEVICE void load(AccessType &value, Scalar_ const *pointer, int offset)
    The load function.
    Definition: load_store.h:277
    +
    Vectorize< Scalar_, kAccessSize >::Type AccessType
    The output type.
    Definition: load_store.h:193
    +
    Vectorize< Scalar_, kAccessSize >::Type AccessType
    The output type.
    Definition: load_store.h:104
    +
    Kind
    Definition: load_store.h:48
    +
    Definition: load_store.h:40
    +
    #define CUTLASS_HOST_DEVICE
    Definition: cutlass.h:46
    +
    Vectorize< Scalar_, kAccessSize >::Type AccessType
    The output type.
    Definition: load_store.h:62
    +
    static CUTLASS_HOST_DEVICE void load(AccessType &dst, Scalar_ const *pointer, int offset)
    The load function.
    Definition: load_store.h:107
    +
    Definition: vector.h:62
    +
    Definition: load_store.h:60
    + +
    static CUTLASS_HOST_DEVICE void load(AccessType &dst, Scalar_ const *pointer, int offset)
    The load function.
    Definition: load_store.h:93
    +
    Definition: load_store.h:48
    +
    Vector< Element_, kLanes_ > Type
    Definition: vector.h:271
    Defines a 1D vector of elements held in the registers of each thread.
    -
    Vectorize< Scalar_, Lanes_ >::Type AccessType
    The output type.
    Definition: load_store.h:160
    -
    static CUTLASS_DEVICE void load(AccessType &dst, Scalar_ const *pointer, int offset)
    The store function.
    Definition: load_store.h:115
    -
    Vectorize< Scalar_, Lanes_ >::Type AccessType
    The output type.
    Definition: load_store.h:69
    -
    static CUTLASS_DEVICE void load(AccessType &dst, double const *pointer, int offset)
    The store function.
    Definition: load_store.h:100
    -
    Vectorize< double, 2 >::Type AccessType
    The output type.
    Definition: load_store.h:97
    -
    Vectorize< Scalar_, Lanes_ >::Type AccessType
    The output type.
    Definition: load_store.h:56
    -
    static CUTLASS_DEVICE void store(AccessType const &src, Scalar_ *pointer, int offset)
    The store function.
    Definition: load_store.h:191
    -
    static CUTLASS_DEVICE void store(AccessType const &src, Scalar_ *pointer, int offset)
    The store function.
    Definition: load_store.h:149
    -
    static CUTLASS_DEVICE void store(AccessType const &src, double *pointer, int offset)
    The store function.
    Definition: load_store.h:177
    +
    Vectorize< Scalar_, kAccessSize >::Type AccessType
    The output type.
    Definition: load_store.h:249
    +
    static CUTLASS_HOST_DEVICE void load(AccessType &value, Vector< uint4_t, 8 > const *pointer, int offset)
    The load function.
    Definition: load_store.h:348
    +
    Definition: numeric_types.h:43
    +
    Vectorize< Scalar_, kAccessSize >::Type AccessType
    The output type.
    Definition: load_store.h:90
    +
    static CUTLASS_HOST_DEVICE void store(AccessType const &src, Scalar_ *pointer, int offset)
    The store function.
    Definition: load_store.h:183
    +
    static CUTLASS_HOST_DEVICE void store(AccessType const &src, Scalar_ *pointer, int offset)
    The store function.
    Definition: load_store.h:224
    +
    Vectorize< Scalar_, kAccessSize >::Type AccessType
    The output type.
    Definition: load_store.h:221
    +
    Vectorize< Scalar_, kAccessSize >::Type AccessType
    The output type.
    Definition: load_store.h:207
    +
    static CUTLASS_HOST_DEVICE void store(AccessType const &value, Scalar_ *pointer, int offset)
    The store function.
    Definition: load_store.h:372
    +
    Definition: numeric_types.h:41
    + +
    static CUTLASS_HOST_DEVICE void store(AccessType const &src, Scalar_ *pointer, int offset)
    The store function.
    Definition: load_store.h:210
    +
    Vectorize< double, 2 >::Type AccessType
    The output type.
    Definition: load_store.h:235
    +
    static CUTLASS_HOST_DEVICE void load(AccessType &dst, Scalar_ const *pointer, int offset)
    The load function.
    Definition: load_store.h:65
    +
    static CUTLASS_HOST_DEVICE void load(AccessType &dst, Scalar_ const *pointer, int offset)
    The load function.
    Definition: load_store.h:80
    +
    static CUTLASS_HOST_DEVICE void load(AccessType &dst, double const *pointer, int offset)
    The load function.
    Definition: load_store.h:122
    diff --git a/docs/matrix__traits_8h.html b/docs/matrix__traits_8h.html index f83c89f0d..097d5f4ae 100644 --- a/docs/matrix__traits_8h.html +++ b/docs/matrix__traits_8h.html @@ -82,17 +82,39 @@ $(function() {

    Defines properties of matrices used to denote layout and operands to GEMM kernels. More...

    - +
    #include "cutlass/coord.h"
    +

    Go to the source code of this file.

    @@ -120,7 +138,7 @@ Namespaces

    + + - + + + + + + + + + + + + + + + + + + + +

    Classes

    struct  cutlass::MatrixCoord
     
    struct  cutlass::MatrixLayout
     Describes layouts of matrices. More...
     Defines data layouts of various matrix formats usable by TensorRef and other classes. More...
     
    struct  cutlass::MatrixLayout::RowMajor
     Mapping function for row-major matrices. More...
     
    struct  cutlass::MatrixLayout::ColumnMajor
     Mapping function for column-major matrices. More...
     
    struct  cutlass::MatrixLayout::RowMajorInterleaved< Interleave >
     
    struct  cutlass::MatrixLayout::ColumnMajorInterleaved< Interleave >
     
    struct  cutlass::MatrixLayout::ContiguousLayout
     
    struct  cutlass::MatrixLayout::ColumnMajorBlockLinear< BlockRows, BlockColumns >
     
    struct  cutlass::MatrixLayout::RowMajorBlockLinear< BlockRows, BlockColumns >
     
    struct  cutlass::GemmOperand
     Gemm operand - D = A * B + C. More...
     
    struct  cutlass::MatrixTransform
     Transformation applied to matrix operands. More...
     
    @@ -102,7 +124,7 @@ Namespaces diff --git a/docs/matrix__traits_8h_source.html b/docs/matrix__traits_8h_source.html index 9f8de2dc6..2e78c5a2b 100644 --- a/docs/matrix__traits_8h_source.html +++ b/docs/matrix__traits_8h_source.html @@ -76,21 +76,88 @@ $(function() {
    matrix_traits.h
    -Go to the documentation of this file.
    1 /***************************************************************************************************
    2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
    3  *
    4  * Redistribution and use in source and binary forms, with or without modification, are permitted
    5  * provided that the following conditions are met:
    6  * * Redistributions of source code must retain the above copyright notice, this list of
    7  * conditions and the following disclaimer.
    8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
    9  * conditions and the following disclaimer in the documentation and/or other materials
    10  * provided with the distribution.
    11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
    12  * to endorse or promote products derived from this software without specific prior written
    13  * permission.
    14  *
    15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
    16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
    17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
    18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
    19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
    20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
    21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    23  *
    24  **************************************************************************************************/
    28 #pragma once
    29 
    30 namespace cutlass {
    31 
    33 
    35 struct MatrixLayout {
    37 };
    38 
    40 
    42 struct GemmOperand {
    43  enum Kind { kA, kB, kC, kD };
    44 };
    45 
    47 
    48 } // namespace cutlass
    Definition: convert.h:33
    -
    Definition: matrix_traits.h:43
    -
    Describes layouts of matrices.
    Definition: matrix_traits.h:35
    -
    Definition: matrix_traits.h:36
    -
    Definition: matrix_traits.h:43
    -
    Gemm operand - D = A * B + C.
    Definition: matrix_traits.h:42
    -
    Definition: matrix_traits.h:36
    -
    Kind
    Definition: matrix_traits.h:36
    -
    Kind
    Definition: matrix_traits.h:43
    -
    Definition: matrix_traits.h:43
    -
    Definition: matrix_traits.h:43
    +Go to the documentation of this file.
    1 /***************************************************************************************************
    2  * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved.
    3  *
    4  * Redistribution and use in source and binary forms, with or without modification, are permitted
    5  * provided that the following conditions are met:
    6  * * Redistributions of source code must retain the above copyright notice, this list of
    7  * conditions and the following disclaimer.
    8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
    9  * conditions and the following disclaimer in the documentation and/or other materials
    10  * provided with the distribution.
    11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
    12  * to endorse or promote products derived from this software without specific prior written
    13  * permission.
    14  *
    15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
    16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
    17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
    18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
    19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
    20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
    21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
    22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
    23  *
    24  **************************************************************************************************/
    28 #pragma once
    29 
    30 #include "cutlass/coord.h"
    31 
    32 namespace cutlass {
    33 
    35 
    38 struct MatrixCoord : public Coord<2, int> {
    39 
    41  typedef int Index;
    42 
    45 
    47  static int const kRow = 0;
    48 
    50  static int const kColumn = 1;
    51 
    52  //
    53  // Methods
    54  //
    55 
    59 
    62  MatrixCoord(Coord<2, Index> const &coord): Base(coord) { }
    63 
    67 
    70  Index const & row() const { return this->at(kRow); }
    71 
    74  Index & row() { return this->at(kRow); }
    75 
    78  Index const & column() const { return this->at(kColumn); }
    79 
    82  Index & column() { return this->at(kColumn); }
    83 
    84  //
    85  // Coord operators
    86  //
    87 
    90  MatrixCoord operator+(Base const& b) const {
    91  return MatrixCoord(Base::operator+(b));
    92  }
    93 
    96  MatrixCoord operator-(Base const& b) const {
    97  return MatrixCoord(Base::operator-(b));
    98  }
    99 
    102  MatrixCoord operator*(Base const& b) const {
    103  return MatrixCoord(Base::operator*(b));
    104  }
    105 
    108  MatrixCoord operator/(Base const& b) const {
    109  return MatrixCoord(Base::operator/(b));
    110  }
    111 
    115  Base::operator+=(b);
    116  return *this;
    117  }
    118 
    122  Base::operator-=(b);
    123  return *this;
    124  }
    125 
    129  Base::operator*=(b);
    130  return *this;
    131  }
    132 
    136  Base::operator/=(b);
    137  return *this;
    138  }
    139 };
    140 
    142 
    144 //
    145 // The following define classes satisfying the TensorRefMapFunc concept. These must support the
    146 // following operations, where func is an instance of type TensorRefMapFunc.
    147 //
    148 // Coord<TensorRefMapFunc::kStorageRank> = func(Coord<kRank>);
    149 //
    150 // Though not required to be usable by TensorRef, each of the following also define a helper
    151 // function to map the "leading dimension" to an appropriate stride vector. Implementations
    152 // following this convention should also implement the following static method:
    153 //
    154 // Coord<TensorRefMapFunc::kStorageRank> stride = TensorRefMapFunc::stride(leading_dim);
    155 //
    156 struct MatrixLayout {
    157 
    160 
    161  //
    162  // TensorRefMapFunc definitions for common layouts
    163  //
    164 
    166  struct RowMajor {
    167  static int const kStorageRank = 2;
    171  return coord;
    172  }
    173  };
    174 
    176  struct ColumnMajor {
    177  static int const kStorageRank = 2;
    181  return make_Coord(coord.column(), coord.row());
    182  }
    183  };
    184 
    187  template <int Interleave>
    189 
    191  static int const kStorageRank = 3;
    192 
    194  static int const kInterleave = Interleave;
    195 
    199  return make_Coord(
    200  coord.row() / kInterleave,
    201  coord.column(),
    202  coord.row() % kInterleave
    203  );
    204  }
    205 
    208  static Coord<kStorageRank> stride(int ldm) {
    209  return make_Coord(
    210  ldm * kInterleave,
    211  kInterleave,
    212  1
    213  );
    214  }
    215  };
    216 
    219  template <int Interleave>
    221 
    223  static int const kStorageRank = 3;
    224 
    226  static int const kInterleave = Interleave;
    227 
    231  return make_Coord(
    232  coord.column() / kInterleave,
    233  coord.row(),
    234  coord.column() % kInterleave
    235  );
    236  }
    237 
    240  static Coord<kStorageRank> stride(int ldm) {
    241  return make_Coord(
    242  ldm * kInterleave,
    243  kInterleave,
    244  1
    245  );
    246  }
    247  };
    248 
    253  static int const kStorageRank = 3;
    254 
    256  static int const kRow = 0;
    257 
    259  static int const kColumn = 1;
    260 
    265  return make_Coord(coord.row(), coord.column(), 0);
    266  }
    267 
    271  if (layout == MatrixLayout::kRowMajor) {
    272  return make_Coord(ldm, 1, 1);
    273  }
    274  return make_Coord(1, ldm, 1);
    275  }
    276  };
    277 
    280  template <int BlockRows, int BlockColumns>
    282 
    284  static int const kStorageRank = 4;
    285 
    287  static int const kBlockRows = BlockRows;
    288 
    290  static int const kBlockColumns = BlockColumns;
    291 
    295  return make_Coord(
    296  coord.column() / kBlockColumns,
    297  coord.row() / kBlockRows,
    298  coord.column() % kBlockColumns,
    299  coord.row() % kBlockRows
    300  );
    301  }
    302 
    305  static Coord<kStorageRank> stride(int ldm) {
    306  return make_Coord(
    307  ldm * kBlockRows * kBlockColumns,
    309  kBlockRows,
    310  1
    311  );
    312  }
    313  };
    314 
    317  template <int BlockRows, int BlockColumns>
    319 
    321  static int const kStorageRank = 4;
    322 
    324  static int const kBlockRows = BlockRows;
    325 
    327  static int const kBlockColumns = BlockColumns;
    328 
    332  return make_Coord(
    333  coord.row() / kBlockRows,
    334  coord.column() / kBlockColumns,
    335  coord.row() % kBlockRows,
    336  coord.column() % kBlockColumns
    337  );
    338  }
    339 
    342  static Coord<kStorageRank> stride(int ldm) {
    343  return make_Coord(
    344  ldm * kBlockRows * kBlockColumns,
    347  1
    348  );
    349  }
    350  };
    351 };
    352 
    354 
    356 struct GemmOperand {
    357  enum Kind { kA, kB, kC, kD };
    358 };
    359 
    361 
    364  enum Kind {
    367  };
    368 };
    369 
    371 
    372 } // namespace cutlass
    int Index
    Integer-valued index.
    Definition: matrix_traits.h:41
    +
    Mapping function for column-major matrices.
    Definition: matrix_traits.h:176
    +
    static int const kBlockColumns
    Interleaving size in columns dimension.
    Definition: matrix_traits.h:327
    +
    Definition: convert.h:33
    +
    CUTLASS_HOST_DEVICE Coord< kStorageRank > operator()(MatrixCoord const &coord) const
    Maps (row, col) to (col, row, col)
    Definition: matrix_traits.h:230
    +
    CUTLASS_HOST_DEVICE Coord< kStorageRank > operator()(MatrixCoord const &coord) const
    Maps (i, j) to (i, j)
    Definition: matrix_traits.h:170
    +
    Transformation applied to matrix operands.
    Definition: matrix_traits.h:363
    +
    Definition: matrix_traits.h:188
    +
    static int const kBlockColumns
    Interleaving size in columns dimension.
    Definition: matrix_traits.h:290
    +
    Definition: matrix_traits.h:365
    +
    Definition: matrix_traits.h:281
    +
    Definition: matrix_traits.h:220
    +
    A Coord is a coordinate of arbitrary rank into a tensor or matrix.
    +
    CUTLASS_HOST_DEVICE Coord< 1 > make_Coord(int _0)
    Helper to make a 1-element coordinate.
    Definition: coord.h:318
    +
    no operation
    Definition: matrix_traits.h:366
    +
    CUTLASS_HOST_DEVICE MatrixCoord & operator/=(Base const &b)
    In-place division.
    Definition: matrix_traits.h:135
    +
    static int const kStorageRank
    Definition: matrix_traits.h:167
    +
    Definition: matrix_traits.h:251
    +
    CUTLASS_HOST_DEVICE Coord< kStorageRank > operator()(MatrixCoord const &coord) const
    Maps (i, j) to (j, i)
    Definition: matrix_traits.h:180
    +
    Kind
    Definition: matrix_traits.h:364
    +
    Coord< 2, Index > Base
    Base type is a Coord of rank=2.
    Definition: matrix_traits.h:44
    +
    CUTLASS_HOST_DEVICE MatrixCoord operator+(Base const &b) const
    Element-wise addition.
    Definition: matrix_traits.h:90
    +
    CUTLASS_HOST_DEVICE Coord & operator*=(Coord const &b)
    In-place multiplication.
    Definition: coord.h:197
    +
    Definition: matrix_traits.h:357
    +
    static int const kRow
    Dimension of rows.
    Definition: matrix_traits.h:256
    +
    static int const kStorageRank
    Definition: matrix_traits.h:177
    +
    static int const kBlockRows
    Interleaving size in rows dimension.
    Definition: matrix_traits.h:287
    +
    Defines data layouts of various matrix formats usable by TensorRef and other classes.
    Definition: matrix_traits.h:156
    +
    static int const kInterleave
    Interleaving size.
    Definition: matrix_traits.h:194
    +
    Definition: matrix_traits.h:159
    +
    CUTLASS_HOST_DEVICE Index const & column() const
    Returns the column of the coordinate.
    Definition: matrix_traits.h:78
    +
    CUTLASS_HOST_DEVICE MatrixCoord(Index row, Index column)
    Helper to construct from a row and column.
    Definition: matrix_traits.h:66
    +
    static CUTLASS_HOST_DEVICE Coord< kStorageRank > stride(int ldm)
    Helper to compute stride vector from leading dimension.
    Definition: matrix_traits.h:208
    +
    static int const kColumn
    Dimension of columns.
    Definition: matrix_traits.h:259
    +
    static int const kStorageRank
    Rank of storage n-D array.
    Definition: matrix_traits.h:191
    +
    CUTLASS_HOST_DEVICE Coord & operator-=(Coord const &b)
    In-place subtraction.
    Definition: coord.h:188
    +
    static int const kStorageRank
    Arbitrary storage rank.
    Definition: matrix_traits.h:253
    +
    Definition: matrix_traits.h:357
    +
    CUTLASS_HOST_DEVICE Coord & operator+=(Coord const &b)
    In-place addition.
    Definition: coord.h:179
    +
    CUTLASS_HOST_DEVICE Coord< kStorageRank > operator()(MatrixCoord const &coord) const
    Maps (row, col) to (row, col, row)
    Definition: matrix_traits.h:198
    +
    #define CUTLASS_HOST_DEVICE
    Definition: cutlass.h:46
    +
    static int const kBlockRows
    Interleaving size in rows dimension.
    Definition: matrix_traits.h:324
    +
    CUTLASS_HOST_DEVICE Index const & row() const
    Returns the row of the coordinate.
    Definition: matrix_traits.h:70
    +
    CUTLASS_HOST_DEVICE Index & at()
    Gets the index of a given Coord element.
    Definition: coord.h:240
    +
    CUTLASS_HOST_DEVICE Coord & operator/=(Coord const &b)
    In-place division.
    Definition: coord.h:206
    +
    CUTLASS_HOST_DEVICE MatrixCoord operator-(Base const &b) const
    Element-wise subtraction.
    Definition: matrix_traits.h:96
    +
    CUTLASS_HOST_DEVICE MatrixCoord(Coord< 2, Index > const &coord)
    Constructs from Coord<2>
    Definition: matrix_traits.h:62
    +
    static int const kStorageRank
    Rank of storage n-D array.
    Definition: matrix_traits.h:321
    +
    Statically-sized array specifying Coords within a tensor.
    Definition: coord.h:49
    +
    Gemm operand - D = A * B + C.
    Definition: matrix_traits.h:356
    +
    static CUTLASS_HOST_DEVICE Coord< kStorageRank > stride(int ldm)
    Helper to compute stride vector from leading dimension.
    Definition: matrix_traits.h:342
    +
    static int const kRow
    Rows dimension.
    Definition: matrix_traits.h:47
    +
    CUTLASS_HOST_DEVICE MatrixCoord & operator-=(Base const &b)
    In-place subtraction.
    Definition: matrix_traits.h:121
    +
    CUTLASS_HOST_DEVICE MatrixCoord operator*(Base const &b) const
    Element-wise multiplication.
    Definition: matrix_traits.h:102
    +
    Definition: matrix_traits.h:159
    +
    static CUTLASS_HOST_DEVICE Coord< kStorageRank > stride(int ldm)
    Helper to compute stride vector from leading dimension.
    Definition: matrix_traits.h:240
    +
    CUTLASS_HOST_DEVICE Coord< kStorageRank > operator()(MatrixCoord const &coord) const
    Definition: matrix_traits.h:264
    +
    Kind
    Enumeration defining fundamental contiguous layouts.
    Definition: matrix_traits.h:159
    +
    CUTLASS_HOST_DEVICE Index & row()
    Returns the row of the coordinate.
    Definition: matrix_traits.h:74
    +
    static int const kStorageRank
    Rank of storage n-D array.
    Definition: matrix_traits.h:284
    +
    static int const kInterleave
    Interleaving size.
    Definition: matrix_traits.h:226
    +
    CUTLASS_HOST_DEVICE Coord< kStorageRank > operator()(MatrixCoord const &coord) const
    Maps (row, col) to (row, col, row, col)
    Definition: matrix_traits.h:331
    +
    Kind
    Definition: matrix_traits.h:357
    +
    Definition: matrix_traits.h:357
    +
    CUTLASS_HOST_DEVICE Index & column()
    Returns the column of the coordinate.
    Definition: matrix_traits.h:82
    +
    CUTLASS_HOST_DEVICE MatrixCoord & operator*=(Base const &b)
    In-place multiplication.
    Definition: matrix_traits.h:128
    +
    static int const kStorageRank
    Rank of storage n-D array.
    Definition: matrix_traits.h:223
    +
    Definition: matrix_traits.h:318
    +
    CUTLASS_HOST_DEVICE MatrixCoord & operator+=(Base const &b)
    In-place addition.
    Definition: matrix_traits.h:114
    +
    static CUTLASS_HOST_DEVICE Coord< kStorageRank > stride(int ldm)
    Helper to compute stride vector from leading dimension.
    Definition: matrix_traits.h:305
    +
    CUTLASS_HOST_DEVICE Coord< kStorageRank > operator()(MatrixCoord const &coord) const
    Maps (row, col) to (col, row, col, row)
    Definition: matrix_traits.h:294
    +
    static CUTLASS_HOST_DEVICE Coord< kStorageRank > stride(MatrixLayout::Kind layout, int ldm)
    Helper to construct a stride vector based on contiguous matrix layout and leading dimension...
    Definition: matrix_traits.h:270
    +
    Definition: matrix_traits.h:38
    +
    CUTLASS_HOST_DEVICE MatrixCoord operator/(Base const &b) const
    Element-wise division.
    Definition: matrix_traits.h:108
    +
    static int const kColumn
    Columns dimension.
    Definition: matrix_traits.h:50
    +
    CUTLASS_HOST_DEVICE MatrixCoord()
    Default ctor.
    Definition: matrix_traits.h:58
    +
    Definition: matrix_traits.h:357
    +
    Mapping function for row-major matrices.
    Definition: matrix_traits.h:166
    diff --git a/docs/menudata.js b/docs/menudata.js index 725988aa8..dde1bbfea 100644 --- a/docs/menudata.js +++ b/docs/menudata.js @@ -29,24 +29,33 @@ var menudata={children:[ {text:"Namespace Members",url:"namespacemembers.html",children:[ {text:"All",url:"namespacemembers.html",children:[ {text:"_",url:"namespacemembers.html#index__"}, +{text:"a",url:"namespacemembers.html#index_a"}, {text:"c",url:"namespacemembers.html#index_c"}, +{text:"e",url:"namespacemembers.html#index_e"}, {text:"f",url:"namespacemembers.html#index_f"}, {text:"g",url:"namespacemembers.html#index_g"}, {text:"i",url:"namespacemembers.html#index_i"}, {text:"l",url:"namespacemembers.html#index_l"}, {text:"m",url:"namespacemembers.html#index_m"}, +{text:"n",url:"namespacemembers.html#index_n"}, {text:"o",url:"namespacemembers.html#index_o"}, +{text:"p",url:"namespacemembers.html#index_p"}, {text:"r",url:"namespacemembers.html#index_r"}, {text:"s",url:"namespacemembers.html#index_s"}, {text:"t",url:"namespacemembers.html#index_t"}]}, {text:"Functions",url:"namespacemembers_func.html",children:[ {text:"_",url:"namespacemembers_func.html#index__"}, +{text:"a",url:"namespacemembers_func.html#index_a"}, {text:"c",url:"namespacemembers_func.html#index_c"}, +{text:"e",url:"namespacemembers_func.html#index_e"}, +{text:"f",url:"namespacemembers_func.html#index_f"}, {text:"g",url:"namespacemembers_func.html#index_g"}, {text:"i",url:"namespacemembers_func.html#index_i"}, {text:"l",url:"namespacemembers_func.html#index_l"}, {text:"m",url:"namespacemembers_func.html#index_m"}, +{text:"n",url:"namespacemembers_func.html#index_n"}, {text:"o",url:"namespacemembers_func.html#index_o"}, +{text:"p",url:"namespacemembers_func.html#index_p"}, {text:"r",url:"namespacemembers_func.html#index_r"}, {text:"s",url:"namespacemembers_func.html#index_s"}]}, {text:"Typedefs",url:"namespacemembers_type.html"}]}]}, @@ -78,6 +87,7 @@ var menudata={children:[ {text:"v",url:"functions_v.html#index_v"}, 
{text:"w",url:"functions_w.html#index_w"}, {text:"y",url:"functions_y.html#index_y"}, +{text:"z",url:"functions_z.html#index_z"}, {text:"~",url:"functions_0x7e.html#index_0x7e"}]}, {text:"Functions",url:"functions_func.html",children:[ {text:"a",url:"functions_func.html#index_a"}, @@ -89,8 +99,10 @@ var menudata={children:[ {text:"g",url:"functions_func_g.html#index_g"}, {text:"h",url:"functions_func_h.html#index_h"}, {text:"i",url:"functions_func_i.html#index_i"}, +{text:"k",url:"functions_func_k.html#index_k"}, {text:"l",url:"functions_func_l.html#index_l"}, {text:"m",url:"functions_func_m.html#index_m"}, +{text:"n",url:"functions_func_n.html#index_n"}, {text:"o",url:"functions_func_o.html#index_o"}, {text:"p",url:"functions_func_p.html#index_p"}, {text:"r",url:"functions_func_r.html#index_r"}, @@ -99,6 +111,7 @@ var menudata={children:[ {text:"u",url:"functions_func_u.html#index_u"}, {text:"v",url:"functions_func_v.html#index_v"}, {text:"w",url:"functions_func_w.html#index_w"}, +{text:"z",url:"functions_func_z.html#index_z"}, {text:"~",url:"functions_func_0x7e.html#index_0x7e"}]}, {text:"Variables",url:"functions_vars.html",children:[ {text:"a",url:"functions_vars.html#index_a"}, @@ -113,6 +126,7 @@ var menudata={children:[ {text:"l",url:"functions_vars_l.html#index_l"}, {text:"m",url:"functions_vars_m.html#index_m"}, {text:"n",url:"functions_vars_n.html#index_n"}, +{text:"o",url:"functions_vars_o.html#index_o"}, {text:"p",url:"functions_vars_p.html#index_p"}, {text:"r",url:"functions_vars_r.html#index_r"}, {text:"s",url:"functions_vars_s.html#index_s"}, @@ -127,6 +141,7 @@ var menudata={children:[ {text:"f",url:"functions_type_f.html#index_f"}, {text:"g",url:"functions_type_g.html#index_g"}, {text:"i",url:"functions_type_i.html#index_i"}, +{text:"k",url:"functions_type_k.html#index_k"}, {text:"l",url:"functions_type_l.html#index_l"}, {text:"m",url:"functions_type_m.html#index_m"}, {text:"n",url:"functions_type_n.html#index_n"}, @@ -140,8 +155,10 @@ var 
menudata={children:[ {text:"Enumerations",url:"functions_enum.html"}, {text:"Enumerator",url:"functions_eval.html",children:[ {text:"a",url:"functions_eval.html#index_a"}, +{text:"b",url:"functions_eval.html#index_b"}, {text:"k",url:"functions_eval.html#index_k"}, {text:"m",url:"functions_eval.html#index_m"}, +{text:"o",url:"functions_eval.html#index_o"}, {text:"v",url:"functions_eval.html#index_v"}]}]}]}, {text:"Files",url:"files.html",children:[ {text:"File List",url:"files.html"}, diff --git a/docs/modules.html b/docs/modules.html index c42247bd4..8fc908440 100644 --- a/docs/modules.html +++ b/docs/modules.html @@ -76,19 +76,20 @@ $(function() {

    Namespaces

    - - - - - - - + + + + + + + +
     Fragment Concept
     Fragment Iterator Concept
     Predicate Vector Concept
     Predicate Iterator Concept
     Predicate Tile Adapter Concept
     Layout Concept
     Tile Traits Concept
     Tile Load Iterator Concept
     Tile Store Iterator Concept
     Identity Block Swizzle
     Predicate Vector Concept
     Predicate Iterator Concept
     Predicate Tile Adapter Concept
     Layout Concept
     Tile Traits Concept
     Tile Load Iterator Concept
     Tile Store Iterator Concept
    diff --git a/docs/namespacecutlass.html b/docs/namespacecutlass.html index 989135cba..4fb1ce9a2 100644 --- a/docs/namespacecutlass.html +++ b/docs/namespacecutlass.html @@ -79,6 +79,8 @@ $(function() { + + @@ -88,24 +90,14 @@ Namespaces Classes + + - - - - - - - - - - - - @@ -129,6 +121,8 @@ Classes + + @@ -143,44 +137,50 @@ Classes + + + - - - - - - - - - - - - + + + + - - + + - + - + + - + - + + + + + + + + + + + @@ -190,22 +190,43 @@ Classes + + - + + + + + + + + + + + + + + + + + + + + @@ -216,6 +237,8 @@ Classes + + @@ -240,19 +263,34 @@ Classes - + - + - + - + + + + + - + + + + + + + - + + + + + + @@ -263,9 +301,15 @@ Classes + + + + + + @@ -282,13 +326,30 @@ Classes + + + + + + + + + + + + + - + + + + + @@ -299,6 +360,20 @@ Classes + + + + + + + + + + + + + +

    Namespaces

     detail
     
     gemm
     
     platform
    struct  AlignedStruct
     
    struct  bin1_t
     
    struct  ComputeOffsetFromShape
     Compute the offset for the given coordinates in a cube. More...
     
    struct  ComputeOffsetFromShape< Shape< 1, kSh_, kSw_, 1 > >
     Compute the offset for the given coordinates in a cube with one channel and a depth of 1. More...
     
    struct  ComputeOffsetFromShape< Shape< 1, kSh_, kSw_, kSc_ > >
     Compute the offset for the given coordinates in a cube with a depth of 1. More...
     
    struct  ComputeOffsetFromStrides
     Compute the offset for the given coordinates in a cube. More...
     
    struct  ComputeOffsetFromStrides< Shape< 1, S_h_, S_w_, 1 > >
     Compute the offset for the given coordinates in a cube with one channel and a depth of 1. More...
     
    struct  ComputeOffsetFromStrides< Shape< 1, S_h_, S_w_, S_c_ > >
     Compute the offset for the given coordinates in a cube with a depth of 1. More...
     
    struct  ComputeThreadOffsetFromStrides
     Decompose threadId.x into coordinate of a cube whose dimensions are specified by Threads_. Afterwards compute the offset of those coordinates using Strides_. More...
     
     
    struct  divide_assert
     
    struct  DumpType
     
    struct  Extent
     Returns the extent of a scalar or vector. More...
     
     
    struct  FragmentConstIterator
     
    struct  FragmentElementType
     Specifies whether iterator storage fragment consists of Scalar values or WMMA matrix. More...
     
    struct  FragmentIterator
     A template defining Fragment Iterator Concept. More...
     
    struct  FragmentLoad
     
    struct  FragmentLoad< IteratorFragment::kScalar, kAccessSize, Scalar_, Memory_, FragmentElement_, kStride >
     
    struct  FragmentLoad< IteratorFragment::kWmmaMatrix, kAccessSize, Scalar_, Memory_, FragmentElement_, kStride >
     
    struct  FragmentStore
     
    struct  FragmentStore< IteratorFragment::kScalar, kAccessSize, Scalar_, Memory_, FragmentElement_, kStride >
     
    struct  FragmentStore< IteratorFragment::kWmmaMatrix, kAccessSize, Scalar_, Memory_, FragmentElement_, kStride >
     
    struct  GemmOperand
     Gemm operand - D = A * B + C. More...
     
    struct  Identity
     Describes identity elements. More...
     
    struct  IdentityTensorMapFunc
     
    struct  int4_t
     
    struct  is_pow2
     
    struct  IteratorAdvance
     Specifies dimension in which post-increment accesses advance. More...
     
    struct  IteratorFragment
     Specifies whether iterator storage fragment consists of Scalar values or WMMA matrix. More...
    struct  KernelLaunchConfiguration
     Structure containing the basic launch configuration of a CUDA kernel. More...
     
    struct  Load
     
    struct  Load< double, 2, Memory_, true, 16 >
    struct  Load< double, 2, Memory_, FragmentElementType::kScalar, double, kStride, 16 >
     
    struct  Load< Scalar_, Lanes_, Memory_, true, 16 >
    struct  Load< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, 1, 2 >
     Partial specialization for 16b loads. More...
     
    struct  Load< Scalar_, Lanes_, Memory_, true, 4 >
    struct  Load< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 16 >
     
    struct  Load< Scalar_, Lanes_, Memory_, true, 8 >
    struct  Load< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 4 >
     
    struct  Load< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 8 >
     
    struct  Load< Scalar_, kAccessSize, Memory_, FragmentElementType::kWmmaMatrix, FragmentElement_, kStride, size >
     
    struct  Load< Vector< bin1_t, 32 >, kAccessSize, Memory_, FragmentElementType::kWmmaMatrix, FragmentElement_, kStride, size >
     
    struct  Load< Vector< int4_t, 8 >, kAccessSize, Memory_, FragmentElementType::kWmmaMatrix, FragmentElement_, kStride, size >
     
    struct  Load< Vector< uint4_t, 8 >, kAccessSize, Memory_, FragmentElementType::kWmmaMatrix, FragmentElement_, kStride, size >
     
    struct  log2_down
     
     
    struct  log2_up< N, 1, Count >
     
    struct  MatrixCoord
     
    struct  MatrixLayout
     Describes layouts of matrices. More...
     Defines data layouts of various matrix formats usable by TensorRef and other classes. More...
     
    struct  MatrixTransform
     Transformation applied to matrix operands. More...
     
    struct  Max
     
    struct  MemorySpace
     Enum to specify which memory space data resides in. More...
     
    struct  Min
     
    struct  PredicatedTileLoadStream
     Generic stream for loading and transforming fragments. More...
     
    struct  PredicatedTileStoreStream
     Generic stream for transforming and storing fragments. More...
     
    struct  PredicateTileAdapter
     Adapter to enable random access to predicates via logical coordinate within a tile. More...
     
    struct  PredicateVector
     Statically sized array of bits implementing. More...
     
    struct  RegularTilePredicateFunctor
     Functor computing a predicate given the logical position of an access. More...
     
    struct  ReshapeTile
     
    struct  ReshapeTile< Tile_, kAccessSize_, true >
     
    struct  ScalarIO
     Helper to enable formatted printing of CUTLASS scalar types to an ostream. More...
     
    struct  Shape
     A Shape implementing Layout Concept describing the dimensions of a cube. More...
     
     
    struct  ShapeDiv
     
    struct  ShapeDivCeiling
     
    struct  ShapeMax
     
    struct  ShapeMin
     
    struct  Store
     
    struct  Store< double, 2, Memory_, true, 16 >
    struct  Store< double, 2, Memory_, FragmentElementType::kScalar, double, kStride, 16 >
     
    struct  Store< Scalar_, Lanes_, Memory_, true, 16 >
    struct  Store< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, 1, 2 >
     
    struct  Store< Scalar_, Lanes_, Memory_, true, 4 >
    struct  Store< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 16 >
     
    struct  Store< Scalar_, Lanes_, Memory_, true, 8 >
    struct  Store< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 4 >
     
    struct  Store< Scalar_, kAccessSize, Memory_, FragmentElementType::kScalar, Scalar_, kStride, 8 >
     
    struct  Store< Scalar_, kAccessSize, Memory_, FragmentElementType::kWmmaMatrix, FragmentElement_, kStride, size >
     
    class  TensorRef
     Structure modeling a pointer and stride into a tensor. More...
     
    class  TensorRef< Storage_, Rank_, MapFunc_, 1, Index_, LongIndex_ >
     Specialization for rank=1 case with no internal StrideVector. More...
     
    struct  TensorRefArray
     
    struct  TensorRefBatchStrided
     
    class  TensorView
     Host-side reference implementation of tensor operations. More...
     Defines a view into a logical tensor. More...
     
    struct  TileAllocation
     Class for storing a tile in memory and accessing it through a tensor ref. More...
     
    struct  TileCoord
     
    struct  TiledThreadOffset
     Basic thread offset function computed from a thread shape. More...
    struct  TileLoadIterator
     An iterator implementing Tile Load Iterator Concept for loading a tile from memory. More...
     
    struct  TileLoadStream
     Generic stream for loading and transforming fragments. More...
     
    struct  TileStoreIterator
     An iterator implementing Tile Store Iterator Concept for storing a tile to memory. More...
     
    struct  TileStoreStream
     Generic stream for transforming and storing fragments. More...
     
    struct  TileTraits
     A template defining Tile Traits Concept. More...
     
    struct  TrivialPredicateTileAdapter
     Always returns true predicate. More...
     
    struct  uint4_t
     
    union  Vector
     
    union  Vector< bin1_t, kLanes_ >
     Vector definition for 1-bit binary datatype. More...
     
    union  Vector< half, 1 >
     
    union  Vector< half, kLanes_ >
     
    union  Vector< int4_t, kLanes_ >
     Vector definition for 4-bit signed integer datatype. More...
     
    union  Vector< uint4_t, kLanes_ >
     Vector definition for 4-bit unsigned integer datatype. More...
     
    struct  Vectorize
     
    struct  Vectorize< Element_, 1 >
    struct  Vectorize< Vector< bin1_t, 32 >, kLanes_ >
     
    struct  Vectorize< Vector< int4_t, 8 >, kLanes_ >
     
    struct  Vectorize< Vector< uint4_t, 8 >, kLanes_ >
     
    struct  VectorTraits
     Traits describing properties of vectors and scalar-as-vectors. More...
    struct  VectorTraits< Vector< T, Lanes > const >
     Partial specialization for actual cutlass::Vector. More...
     
    struct  ZipConvert
     Zips two convert operations. More...
     
    struct  ZipFragment
     A template defining Fragment Concept. More...
     
    struct  ZipTensorRef
     
    struct  ZipTileAllocation
     Manages a pair of tile allocations as if they are one allocation. More...
     
    class  ZipTileIterator
     Constructs an iterator from a pair of iterators. More...
     
    @@ -314,85 +389,42 @@ Functions - - - - - - - - - - - - + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - @@ -402,6 +434,12 @@ Functions + + + + + + @@ -426,12 +464,24 @@ Functions - - - - - - + + + + + + + + + + + + + + + + + +

    Functions

    CUTLASS_HOST_DEVICE Coord< 4 > make_Coord (int _0, int _1, int _2, int _3)
     Helper to make a 4-element coordinate. More...
     
    CUTLASS_HOST_DEVICE Coord< 2 > get_Coord_hw (Coord< 3 > const &coord)
     Getter. More...
     
    CUTLASS_HOST_DEVICE Coord< 2 > get_Coord_hw (Coord< 4 > const &coord)
     Getter. More...
     
    CUTLASS_HOST_DEVICE Coord< 3 > get_Coord_hwc (Coord< 4 > const &coord)
     Getter. More...
     
    CUTLASS_HOST_DEVICE Coord< 3 > get_Coord_dhw (Coord< 4 > const &coord)
     Getter. More...
     
    template<typename Shape_ >
    CUTLASS_HOST_DEVICE Coord< 3 > make_Coord_from_shape ()
     
    template<int Rank>
    std::ostream & operator<< (std::ostream &out, Coord< Rank > const &coord)
     
    template<typename T >
    std::ostream & operator<< (std::ostream &out, ScalarIO< T > const &scalar)
     Default printing to ostream. More...
     
    template<>
    std::ostream & operator<< (std::ostream &out, ScalarIO< int8_t > const &scalar)
     Printing to ostream of int8_t as integer rather than character. More...
     
    template<>
    std::ostream & operator<< (std::ostream &out, ScalarIO< uint8_t > const &scalar)
     Printing to ostream of uint8_t as integer rather than character. More...
     
    template<>
    std::ostream & operator<< (std::ostream &out, ScalarIO< cutlass::Vector< cutlass::bin1_t, 32 > > const &scalar)
     Printing to ostream of vector of 1b elements. More...
     
    template<>
    std::ostream & operator<< (std::ostream &out, ScalarIO< cutlass::Vector< cutlass::int4_t, 8 > > const &scalar)
     Printing to ostream of vector of 4b signed integer elements. More...
     
    template<>
    std::ostream & operator<< (std::ostream &out, ScalarIO< cutlass::Vector< cutlass::uint4_t, 8 > > const &scalar)
     Printing to ostream of vector of 4b unsigned integer elements. More...
     
    template<typename InputIterator , typename Fragment >
    CUTLASS_HOST_DEVICE void iterator_load (InputIterator &iterator, Fragment &fragment)
     Loads a fragment from an input iterator. More...
     
    template<typename InputIterator , typename Fragment >
    CUTLASS_DEVICE void shared_iterator_load (InputIterator &iterator, Fragment &fragment)
     Loads a fragment from a shared memory input iterator. More...
     
    template<typename InputIterator , typename Fragment >
    CUTLASS_DEVICE void shared_iterator_load (InputIterator &iterator, Fragment &fragment, int d)
     Loads a fragment from a shared memory input iterator. More...
     
    template<typename InputIterator , typename Fragment , typename ConstPredicateAdapter >
    CUTLASS_HOST_DEVICE void iterator_load_post_increment (InputIterator &iterator, Fragment &fragment, typename InputIterator::Index offset, ConstPredicateAdapter predicate_adapter)
     Loads a fragment from an input iterator, masked by a predicate iterator. More...
     
    template<typename InputIterator , typename Fragment >
    CUTLASS_HOST_DEVICE void iterator_load_post_increment (InputIterator &iterator, Fragment &fragment, typename InputIterator::Index offset=0)
     Loads a fragment from an input iterator. More...
     
    template<typename InputIterator , typename Fragment , typename ConstPredicateAdapter >
    CUTLASS_HOST_DEVICE void iterator_load_post_increment (InputIterator &iterator, Fragment &fragment, ConstPredicateAdapter pred_it)
     Loads a fragment from an input iterator. More...
     
    template<typename InputIterator , typename Fragment , typename ConstPredicateAdapter >
    CUTLASS_HOST_DEVICE void iterator_load (InputIterator const &_iterator, Fragment &fragment, typename InputIterator::Index offset, ConstPredicateAdapter predicate_adapter)
     
    template<typename InputIterator , typename Fragment >
    CUTLASS_HOST_DEVICE void iterator_load (InputIterator const &iterator, Fragment &fragment, typename InputIterator::Index offset=0)
     Loads a fragment from an input iterator. More...
     
    template<typename InputIterator , typename Fragment , typename ConstPredicateAdapter >
    CUTLASS_HOST_DEVICE void iterator_load (InputIterator const &iterator, Fragment &fragment, ConstPredicateAdapter pred_it)
     Loads a fragment from an input iterator. More...
     
    template<typename OutputIterator , typename Fragment >
    CUTLASS_HOST_DEVICE void iterator_store (OutputIterator &iterator, Fragment &fragment)
     Stores a fragment to an output iterator. More...
     
    template<typename OutputIterator , typename Fragment >
    CUTLASS_DEVICE void shared_iterator_store (OutputIterator &iterator, Fragment const &fragment)
     Stores a fragment to a shared memory output iterator. More...
     
    template<typename OutputIterator , typename Fragment , typename ConstPredicateAdapter >
    CUTLASS_HOST_DEVICE void iterator_store_post_increment (OutputIterator &iterator, Fragment const &fragment, typename OutputIterator::Index offset, ConstPredicateAdapter predicate_adapter)
     Stores a fragment to an output iterator, masked by a predicate iterator. More...
     
    template<typename OutputIterator , typename Fragment >
    CUTLASS_HOST_DEVICE void iterator_store_post_increment (OutputIterator &iterator, Fragment const &fragment, typename OutputIterator::Index offset=0)
     Stores a fragment to an output iterator. More...
     
    template<typename OutputIterator , typename Fragment , typename ConstPredicateAdapter >
    CUTLASS_HOST_DEVICE void iterator_store_post_increment (OutputIterator &iterator, Fragment const &fragment, ConstPredicateAdapter pred_it)
     Stores a fragment to an output iterator. More...
     
    template<typename OutputIterator , typename Fragment , typename ConstPredicateAdapter >
    CUTLASS_HOST_DEVICE void iterator_store (OutputIterator const &_iterator, Fragment const &fragment, typename OutputIterator::Index offset, ConstPredicateAdapter predicate_adapter)
     Stores a fragment to an output iterator, masked by a predicate iterator. More...
     
    template<typename OutputIterator , typename Fragment >
    CUTLASS_HOST_DEVICE void iterator_store (OutputIterator const &iterator, Fragment const &fragment, typename OutputIterator::Index offset=0)
     Stores a fragment to an output iterator. More...
     
    template<typename OutputIterator , typename Fragment , typename ConstPredicateAdapter >
    CUTLASS_HOST_DEVICE void iterator_store (OutputIterator const &iterator, Fragment const &fragment, ConstPredicateAdapter pred_it)
     Stores a fragment to an output iterator. More...
     
    template<typename dividend_t , typename divisor_t >
    CUTLASS_HOST_DEVICE dividend_t round_nearest (dividend_t dividend, divisor_t divisor)
     
    template<typename value_t >
    CUTLASS_HOST_DEVICE value_t lcm (value_t a, value_t b)
     
    template<typename value_t >
    CUTLASS_HOST_DEVICE value_t clz (value_t x)
     
    template<typename value_t >
    CUTLASS_HOST_DEVICE value_t find_log2 (value_t x)
     
    __host__ CUTLASS_DEVICE cudaError_t cuda_perror_impl (cudaError_t error, const char *filename, int line)
     The corresponding error message is printed to stderr (or stdout in device code) along with the supplied source context. More...
     
    template<>
    struct __align__ (64) AlignedStruct< 64 >
     
    template<typename Scalar_ >
    CUTLASS_DEVICE void make_zero (Scalar_ &x)
     
    template<typename Scalar_ , int kLanes_>
    CUTLASS_DEVICE void make_zero (Vector< Scalar_, kLanes_ > &vec)
     
    template<typename Scalar_ >
    CUTLASS_HOST_DEVICE void make_zero (Scalar_ &x)
     
    template<typename Scalar_ , int kLanes_>
    CUTLASS_HOST_DEVICE void make_zero (Vector< Scalar_, kLanes_ > &vec)
     
    template<typename First , typename Second >
    CUTLASS_HOST_DEVICE ZipFragment< First, Second > make_ZipFragment (First const &first, Second const &second)
     Helper to construct a ZipFragment object. More...
     
    template<typename First , typename Second >
    CUTLASS_HOST_DEVICE ZipConvert< First, Second > make_ZipConvert (First const &first, Second const &second)
     Helper to construct a ZipConvert object. More...
     
    template<typename First , typename Second >
    CUTLASS_HOST_DEVICE ZipTensorRef< First, Second > make_ZipTensorRef (First const &first, Second const &second)
     Constructs a ZipTensorRef. More...
     

    Function Documentation

    @@ -452,10 +502,30 @@ template<>
    +
    + + +

    ◆ __align__() [2/7]

    + +
    +
    +
    +template<>
    + + + + + + + + +
    struct cutlass::__align__ ()
    +
    +
    -

    ◆ __align__() [2/7]

    +

    ◆ __align__() [3/7]

    @@ -475,7 +545,7 @@ template<>
    -

    ◆ __align__() [3/7]

    +

    ◆ __align__() [4/7]

    @@ -495,7 +565,7 @@ template<>
    -

    ◆ __align__() [4/7]

    +

    ◆ __align__() [5/7]

    @@ -515,7 +585,7 @@ template<>
    -

    ◆ __align__() [5/7]

    +

    ◆ __align__() [6/7]

    @@ -535,7 +605,7 @@ template<>
    -

    ◆ __align__() [6/7]

    +

    ◆ __align__() [7/7]

    @@ -554,23 +624,24 @@ template<>
    - -

    ◆ __align__() [7/7]

    + +

    ◆ clz()

    -template<>
    +template<typename value_t >
    - + - - + +
    struct cutlass::__align__ CUTLASS_HOST_DEVICE value_t cutlass::clz ()value_t x)
    +

    log2 computation, what's the difference between the below codes and log2_up/down codes?

    @@ -607,6 +678,26 @@ template<>
    Returns
    The CUDA error.
    +
    + + +

    ◆ find_log2()

    + +
    +
    +
    +template<typename value_t >
    + + + + + + + + +
    CUTLASS_HOST_DEVICE value_t cutlass::find_log2 (value_t x)
    +
    +
    @@ -638,82 +729,10 @@ template<typename value_t >

    Greatest common divisor

    -
    - - -

    ◆ get_Coord_dhw()

    - -
    -
    - - - - - - - - -
    CUTLASS_HOST_DEVICE Coord<3> cutlass::get_Coord_dhw (Coord< 4 > const & coord)
    -
    - -
    -
    - -

    ◆ get_Coord_hw() [1/2]

    - -
    -
    - - - - - - - - -
    CUTLASS_HOST_DEVICE Coord<2> cutlass::get_Coord_hw (Coord< 3 > const & coord)
    -
    - -
    -
    - -

    ◆ get_Coord_hw() [2/2]

    - -
    -
    - - - - - - - - -
    CUTLASS_HOST_DEVICE Coord<2> cutlass::get_Coord_hw (Coord< 4 > const & coord)
    -
    - -
    -
    - -

    ◆ get_Coord_hwc()

    - -
    -
    - - - - - - - - -
    CUTLASS_HOST_DEVICE Coord<3> cutlass::get_Coord_hwc (Coord< 4 > const & coord)
    -
    -
    -

    ◆ iterator_load() [1/4]

    +

    ◆ iterator_load()

    @@ -740,238 +759,10 @@ template<typename InputIterator , typename Fragment >
    -
    - - -

    ◆ iterator_load() [2/4]

    - -
    -
    -
    -template<typename InputIterator , typename Fragment , typename ConstPredicateAdapter >
    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    CUTLASS_HOST_DEVICE void cutlass::iterator_load (InputIterator const & _iterator,
    Fragmentfragment,
    typename InputIterator::Index offset,
    ConstPredicateAdapter predicate_adapter 
    )
    -
    - -
    -
    - -

    ◆ iterator_load() [3/4]

    - -
    -
    -
    -template<typename InputIterator , typename Fragment >
    - - - - - - - - - - - - - - - - - - - - - - - - -
    CUTLASS_HOST_DEVICE void cutlass::iterator_load (InputIterator const & iterator,
    Fragmentfragment,
    typename InputIterator::Index offset = 0 
    )
    -
    - -
    -
    - -

    ◆ iterator_load() [4/4]

    - -
    -
    -
    -template<typename InputIterator , typename Fragment , typename ConstPredicateAdapter >
    - - - - - - - - - - - - - - - - - - - - - - - - -
    CUTLASS_HOST_DEVICE void cutlass::iterator_load (InputIterator const & iterator,
    Fragmentfragment,
    ConstPredicateAdapter pred_it 
    )
    -
    - -
    -
    - -

    ◆ iterator_load_post_increment() [1/3]

    - -
    -
    -
    -template<typename InputIterator , typename Fragment , typename ConstPredicateAdapter >
    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    CUTLASS_HOST_DEVICE void cutlass::iterator_load_post_increment (InputIterator & iterator,
    Fragmentfragment,
    typename InputIterator::Index offset,
    ConstPredicateAdapter predicate_adapter 
    )
    -
    - -
    -
    - -

    ◆ iterator_load_post_increment() [2/3]

    - -
    -
    -
    -template<typename InputIterator , typename Fragment >
    - - - - - - - - - - - - - - - - - - - - - - - - -
    CUTLASS_HOST_DEVICE void cutlass::iterator_load_post_increment (InputIterator & iterator,
    Fragmentfragment,
    typename InputIterator::Index offset = 0 
    )
    -
    - -
    -
    - -

    ◆ iterator_load_post_increment() [3/3]

    - -
    -
    -
    -template<typename InputIterator , typename Fragment , typename ConstPredicateAdapter >
    - - - - - - - - - - - - - - - - - - - - - - - - -
    CUTLASS_HOST_DEVICE void cutlass::iterator_load_post_increment (InputIterator & iterator,
    Fragmentfragment,
    ConstPredicateAdapter pred_it 
    )
    -
    -
    -

    ◆ iterator_store() [1/4]

    +

    ◆ iterator_store()

    @@ -998,234 +789,6 @@ template<typename OutputIterator , typename Fragment >
    -
    - - -

    ◆ iterator_store() [2/4]

    - -
    -
    -
    -template<typename OutputIterator , typename Fragment , typename ConstPredicateAdapter >
    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    CUTLASS_HOST_DEVICE void cutlass::iterator_store (OutputIterator const & _iterator,
    Fragment const & fragment,
    typename OutputIterator::Index offset,
    ConstPredicateAdapter predicate_adapter 
    )
    -
    - -
    -
    - -

    ◆ iterator_store() [3/4]

    - -
    -
    -
    -template<typename OutputIterator , typename Fragment >
    - - - - - - - - - - - - - - - - - - - - - - - - -
    CUTLASS_HOST_DEVICE void cutlass::iterator_store (OutputIterator const & iterator,
    Fragment const & fragment,
    typename OutputIterator::Index offset = 0 
    )
    -
    - -
    -
    - -

    ◆ iterator_store() [4/4]

    - -
    -
    -
    -template<typename OutputIterator , typename Fragment , typename ConstPredicateAdapter >
    - - - - - - - - - - - - - - - - - - - - - - - - -
    CUTLASS_HOST_DEVICE void cutlass::iterator_store (OutputIterator const & iterator,
    Fragment const & fragment,
    ConstPredicateAdapter pred_it 
    )
    -
    - -
    -
    - -

    ◆ iterator_store_post_increment() [1/3]

    - -
    -
    -
    -template<typename OutputIterator , typename Fragment , typename ConstPredicateAdapter >
    - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    CUTLASS_HOST_DEVICE void cutlass::iterator_store_post_increment (OutputIterator & iterator,
    Fragment const & fragment,
    typename OutputIterator::Index offset,
    ConstPredicateAdapter predicate_adapter 
    )
    -
    - -
    -
    - -

    ◆ iterator_store_post_increment() [2/3]

    - -
    -
    -
    -template<typename OutputIterator , typename Fragment >
    - - - - - - - - - - - - - - - - - - - - - - - - -
    CUTLASS_HOST_DEVICE void cutlass::iterator_store_post_increment (OutputIterator & iterator,
    Fragment const & fragment,
    typename OutputIterator::Index offset = 0 
    )
    -
    - -
    -
    - -

    ◆ iterator_store_post_increment() [3/3]

    - -
    -
    -
    -template<typename OutputIterator , typename Fragment , typename ConstPredicateAdapter >
    - - - - - - - - - - - - - - - - - - - - - - - - -
    CUTLASS_HOST_DEVICE void cutlass::iterator_store_post_increment (OutputIterator & iterator,
    Fragment const & fragment,
    ConstPredicateAdapter pred_it 
    )
    -
    -
    @@ -1379,8 +942,27 @@ template<typename value_t > - -

    ◆ make_zero() [1/2]

    + +

    ◆ make_Coord_from_shape()

    + +
    +
    +
    +template<typename Shape_ >
    + + + + + + + +
    CUTLASS_HOST_DEVICE Coord<3> cutlass::make_Coord_from_shape ()
    +
    + +
    +
    + +

    ◆ make_zero() [1/2]

    @@ -1388,7 +970,7 @@ template<typename value_t >
    template<typename Scalar_ >
    - + @@ -1399,8 +981,8 @@ template<typename Scalar_ > - -

    ◆ make_zero() [2/2]

    + +

    ◆ make_zero() [2/2]

    @@ -1408,7 +990,7 @@ template<typename Scalar_ >
    template<typename Scalar_ , int kLanes_>
    CUTLASS_DEVICE void cutlass::make_zero CUTLASS_HOST_DEVICE void cutlass::make_zero ( Scalar_ &  x)
    - + @@ -1417,6 +999,354 @@ template<typename Scalar_ , int kLanes_>
    CUTLASS_DEVICE void cutlass::make_zero CUTLASS_HOST_DEVICE void cutlass::make_zero ( Vector< Scalar_, kLanes_ > &  vec)
    +
    + + +

    ◆ make_ZipConvert()

    + +
    +
    +
    +template<typename First , typename Second >
    + + + + + + + + + + + + + + + + + + +
    CUTLASS_HOST_DEVICE ZipConvert<First, Second> cutlass::make_ZipConvert (First const & first,
    Second const & second 
    )
    +
    + +
    +
    + +

    ◆ make_ZipFragment()

    + +
    +
    +
    +template<typename First , typename Second >
    + + + + + + + + + + + + + + + + + + +
    CUTLASS_HOST_DEVICE ZipFragment<First, Second> cutlass::make_ZipFragment (First const & first,
    Second const & second 
    )
    +
    + +
    +
    + +

    ◆ make_ZipTensorRef()

    + +
    +
    +
    +template<typename First , typename Second >
    + + + + + + + + + + + + + + + + + + +
    CUTLASS_HOST_DEVICE ZipTensorRef<First, Second> cutlass::make_ZipTensorRef (First const & first,
    Second const & second 
    )
    +
    + +
    +
    + +

    ◆ operator<<() [1/7]

    + +
    +
    +
    +template<int Rank>
    + + + + + + + + + + + + + + + + + + +
    std::ostream& cutlass::operator<< (std::ostream & out,
    Coord< Rank > const & coord 
    )
    +
    + +
    +
    + +

    ◆ operator<<() [2/7]

    + +
    +
    +
    +template<typename T >
    + + + + + +
    + + + + + + + + + + + + + + + + + + +
    std::ostream& cutlass::operator<< (std::ostream & out,
    ScalarIO< T > const & scalar 
    )
    +
    +inline
    +
    + +
    +
    + +

    ◆ operator<<() [3/7]

    + +
    +
    +
    +template<>
    + + + + + +
    + + + + + + + + + + + + + + + + + + +
    std::ostream& cutlass::operator<< (std::ostream & out,
    ScalarIO< int8_t > const & scalar 
    )
    +
    +inline
    +
    + +
    +
    + +

    ◆ operator<<() [4/7]

    + +
    +
    +
    +template<>
    + + + + + +
    + + + + + + + + + + + + + + + + + + +
    std::ostream& cutlass::operator<< (std::ostream & out,
    ScalarIO< uint8_t > const & scalar 
    )
    +
    +inline
    +
    + +
    +
    + +

    ◆ operator<<() [5/7]

    + +
    +
    +
    +template<>
    + + + + + +
    + + + + + + + + + + + + + + + + + + +
    std::ostream& cutlass::operator<< (std::ostream & out,
    ScalarIO< cutlass::Vector< cutlass::bin1_t, 32 > > const & scalar 
    )
    +
    +inline
    +
    + +
    +
    + +

    ◆ operator<<() [6/7]

    + +
    +
    +
    +template<>
    + + + + + +
    + + + + + + + + + + + + + + + + + + +
    std::ostream& cutlass::operator<< (std::ostream & out,
    ScalarIO< cutlass::Vector< cutlass::int4_t, 8 > > const & scalar 
    )
    +
    +inline
    +
    + +
    +
    + +

    ◆ operator<<() [7/7]

    + +
    +
    +
    +template<>
    + + + + + +
    + + + + + + + + + + + + + + + + + + +
    std::ostream& cutlass::operator<< (std::ostream & out,
    ScalarIO< cutlass::Vector< cutlass::uint4_t, 8 > > const & scalar 
    )
    +
    +inline
    +
    +
    @@ -1448,108 +1378,12 @@ template<typename dividend_t , typename divisor_t >

    Round dividend up to the nearest multiple of divisor

    -
    - - -

    ◆ shared_iterator_load() [1/2]

    - -
    -
    -
    -template<typename InputIterator , typename Fragment >
    - - - - - - - - - - - - - - - - - - -
    CUTLASS_DEVICE void cutlass::shared_iterator_load (InputIterator & iterator,
    Fragmentfragment 
    )
    -
    - -
    -
    - -

    ◆ shared_iterator_load() [2/2]

    - -
    -
    -
    -template<typename InputIterator , typename Fragment >
    - - - - - - - - - - - - - - - - - - - - - - - - -
    CUTLASS_DEVICE void cutlass::shared_iterator_load (InputIterator & iterator,
    Fragmentfragment,
    int d 
    )
    -
    - -
    -
    - -

    ◆ shared_iterator_store()

    - -
    -
    -
    -template<typename OutputIterator , typename Fragment >
    - - - - - - - - - - - - - - - - - - -
    CUTLASS_DEVICE void cutlass::shared_iterator_store (OutputIterator & iterator,
    Fragment const & fragment 
    )
    -
    -
    diff --git a/docs/namespacecutlass_1_1detail.html b/docs/namespacecutlass_1_1detail.html new file mode 100644 index 000000000..154ce5c45 --- /dev/null +++ b/docs/namespacecutlass_1_1detail.html @@ -0,0 +1,95 @@ + + + + + + + +Cutlass: cutlass::detail Namespace Reference + + + + + + + + + + +
    +
    + + + + + + +
    +
    Cutlass +
    +
    CUDA Templates for Linear Algebra Subroutines and Solvers
    +
    +
    + + + + + + + + +
    +
    + + +
    + +
    + + +
    +
    + +
    +
    cutlass::detail Namespace Reference
    +
    +
    + + + + +

    +Classes

    class  ScalarOrPointer
     
    +
    + + + + diff --git a/docs/namespacecutlass_1_1gemm.html b/docs/namespacecutlass_1_1gemm.html index 1c84e4480..1545f43f4 100644 --- a/docs/namespacecutlass_1_1gemm.html +++ b/docs/namespacecutlass_1_1gemm.html @@ -84,19 +84,28 @@ $(function() { Classes struct  ClearAccumulators   +struct  ColumnMajorBlockSwizzle +  struct  DgemmConfig   struct  DgemmTraits   +struct  Fp16SgemmConfig +  +struct  Fp16SgemmSgemmTraits +  struct  FragmentMultiplyAdd   -struct  FragmentMultiplyAdd< half > +struct  FragmentMultiplyAdd< half, half, true >   struct  Gemm   struct  GemmConfig   +struct  GemmCoord +  struct  GemmDesc + GEMM problem description. More...
      struct  GemmEpilogue   @@ -151,7 +160,8 @@ Classes   struct  GlobalLoadStream   -struct  GlobalLoadStreamBase +struct  GlobalLoadStreamPair + Collect the global load streams for multiplicands. More...
      struct  HgemmConfig   @@ -187,9 +197,7 @@ Classes   struct  IgemmConfig   -struct  IgemmConfig< OutputTile_, int8_t, AccumulatorsPerThread_ > -  -struct  IgemmContiguousGlobalTileTraits +struct  IgemmConfig< OutputTile_, int8_t, ThreadGemmShape_ >   struct  IgemmEpilogue   @@ -205,6 +213,8 @@ Classes   struct  IgemmFloatToInt8Converter   +struct  IgemmGlobalIteratorAb +  struct  IgemmGlobalLoadTransformer   struct  IgemmGlobalLoadTransformer< Fragment< int8_t, kElements_ >, float > @@ -213,6 +223,8 @@ Classes   struct  IgemmGlobalStoreTransformer< float, Fragment< int8_t, kElements_ > >   +struct  IgemmGlobalTileTraits +  struct  IgemmInt8ToFloatConverter   struct  IgemmSharedStoreTransformer @@ -221,11 +233,15 @@ Classes   struct  IgemmTileTraitsHelperA   -struct  IgemmTileTraitsHelperA< MatrixLayout::kColumnMajor, GemmConfig_ > +struct  IgemmTileTraitsHelperA< MatrixLayout::kColumnMajor, GemmConfig_, Index_ > +  +struct  IgemmTileTraitsHelperA< MatrixLayout::kRowMajor, GemmConfig_, Index_ >   struct  IgemmTileTraitsHelperB   -struct  IgemmTileTraitsHelperB< MatrixLayout::kRowMajor, GemmConfig_ > +struct  IgemmTileTraitsHelperB< MatrixLayout::kColumnMajor, GemmConfig_, Index_ > +  +struct  IgemmTileTraitsHelperB< MatrixLayout::kRowMajor, GemmConfig_, Index_ >   struct  IgemmTraits   @@ -243,9 +259,17 @@ Classes   struct  IgemmTransformerB< MatrixLayout::kRowMajor, Iterator_ >   +struct  Launch + Partial specialization for launching the GEMM kernel with or without launch bounds. More...
    +  +struct  Launch< Gemm, false > + Partial specialization for launching the GEMM kernel with or without launch bounds. More...
    +  struct  LinearScaling  Functor to compute linear combination of fragments. More...
      +struct  LinearScalingDevicePtr +  struct  ProjectOperand   struct  ProjectOperand< GemmOperand::kA, Kstrided > @@ -264,26 +288,39 @@ Classes   struct  ReshapeThreads< Tile_, Threads_, true >   +struct  RowMajorBlockSwizzle +  struct  SgemmConfig   +struct  SgemmLBTraits + Helper to define SGEMM traits using Launch Bounds. More...
    +  struct  SgemmTraits   struct  SharedLoadStream   +struct  SharedStreamPair + Collect the global load streams for multiplicands. More...
    +  struct  SimplifiedGemmEpilogueTraits   struct  SimplifiedGemmTraits   struct  SimplifiedGemmTraitsHelper   +struct  swizzleDirection +  struct  ThreadMultiplyAdd  Template performing matrix multiply-add operation within a thread. More...
      -struct  ThreadMultiplyAdd< AccumulatorsPerThread_, ThreadsPerWarp_, half, half, half > - Template performing matrix multiply-add operation within a thread. More...
    +struct  ThreadMultiplyAdd< ThreadGemmShape_, ThreadsPerWarp_, half, half, float > + Template performing matrix multiply-add operation within a thread. More...
      -struct  ThreadMultiplyAdd< AccumulatorsPerThread_, ThreadsPerWarp_, int8_t, int8_t, int > - Template performing matrix multiply-add operation within a thread. More...
    +struct  ThreadMultiplyAdd< ThreadGemmShape_, ThreadsPerWarp_, half, half, half > + Template performing matrix multiply-add operation within a thread. More...
    +  +struct  ThreadMultiplyAdd< ThreadGemmShape_, ThreadsPerWarp_, int8_t, int8_t, int > + Template performing matrix multiply-add operation within a thread. More...
      struct  WmmaGemmGlobalIteratorCd   @@ -292,18 +329,29 @@ Classes - - - + + + + + + + + + + + + + +

    Functions

    template<typename Gemm_ >
    __global__ void gemm_kernel (typename Gemm_::Params params)
     
    template<typename Gemm_ >
    __global__ __launch_bounds__ (Gemm_::kThreads) void gemm_kernel(typename Gemm_
     GEMM kernel with launch bounds specified. More...
     
    template<typename Gemm_ >
    __global__ void gemm_kernel_nolb (typename Gemm_::Params params)
     GEMM kernel without launch bounds specified. More...
     
    template<typename T >
    CUTLASS_DEVICE bool is_zero (T x)
     
    CUTLASS_DEVICE bool is_zero (half x)
     
    template<enum swizzleDirection::Kind >
    CUTLASS_DEVICE int getLinearIdx (int groups)
     
    template<>
    CUTLASS_DEVICE int getLinearIdx< swizzleDirection::Boustrophedon > (int groups)
     

    Function Documentation

    - -

    ◆ gemm_kernel()

    + +

    ◆ __launch_bounds__()

    @@ -311,7 +359,27 @@ Functions template<typename Gemm_ >
    - + + + + + + +
    __global__ void cutlass::gemm::gemm_kernel __global__ cutlass::gemm::__launch_bounds__ (Gemm_::kThreads )
    +
    + +
    + + +

    ◆ gemm_kernel_nolb()

    + +
    +
    +
    +template<typename Gemm_ >
    + + + @@ -320,6 +388,46 @@ template<typename Gemm_ >
    __global__ void cutlass::gemm::gemm_kernel_nolb ( typename Gemm_::Params  params)
    +
    +
    + +

    ◆ getLinearIdx()

    + +
    +
    +
    +template<enum swizzleDirection::Kind >
    + + + + + + + + +
    CUTLASS_DEVICE int cutlass::gemm::getLinearIdx (int groups)
    +
    + +
    +
    + +

    ◆ getLinearIdx< swizzleDirection::Boustrophedon >()

    + +
    +
    +
    +template<>
    + + + + + + + + +
    CUTLASS_DEVICE int cutlass::gemm::getLinearIdx< swizzleDirection::Boustrophedon > (int groups)
    +
    +
    @@ -363,7 +471,7 @@ template<typename T > diff --git a/docs/namespacecutlass_1_1platform.html b/docs/namespacecutlass_1_1platform.html index 2bf30c0df..b62a896a7 100644 --- a/docs/namespacecutlass_1_1platform.html +++ b/docs/namespacecutlass_1_1platform.html @@ -122,6 +122,8 @@ Classes struct  bool_constant  std::bool_constant More...
      +class  complex +  struct  conditional  std::conditional (true specialization) More...
      @@ -256,6 +258,157 @@ Typedefs + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + @@ -399,8 +552,8 @@ template<> - -

    ◆ __align__() [3/13]

    + +

    ◆ __align__() [3/13]

    @@ -410,7 +563,7 @@ template<>
    - + @@ -419,8 +572,8 @@ template<> - -

    ◆ __align__() [4/13]

    + +

    ◆ __align__() [4/13]

    @@ -430,7 +583,7 @@ template<>
    - + @@ -439,8 +592,8 @@ template<> - -

    ◆ __align__() [5/13]

    + +

    ◆ __align__() [5/13]

    @@ -450,7 +603,7 @@ template<>
    - + @@ -477,10 +630,90 @@ template<>

    Functions

    CUTLASS_HOST_DEVICE float const & real (cuFloatComplex const &z)
     Returns the real part of the complex number. More...
     
    CUTLASS_HOST_DEVICE float & real (cuFloatComplex &z)
     Returns the real part of the complex number. More...
     
    CUTLASS_HOST_DEVICE double const & real (cuDoubleComplex const &z)
     Returns the real part of the complex number. More...
     
    CUTLASS_HOST_DEVICE double & real (cuDoubleComplex &z)
     Returns the real part of the complex number. More...
     
    CUTLASS_HOST_DEVICE float const & imag (cuFloatComplex const &z)
     Returns the imaginary part of the complex number. More...
     
    CUTLASS_HOST_DEVICE float & imag (cuFloatComplex &z)
     Returns the imaginary part of the complex number. More...
     
    CUTLASS_HOST_DEVICE double const & imag (cuDoubleComplex const &z)
     Returns the imaginary part of the complex number. More...
     
    CUTLASS_HOST_DEVICE double & imag (cuDoubleComplex &z)
     Returns the imaginary part of the complex number. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICE T const & real (complex< T > const &z)
     Returns the real part of the complex number. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICE T & real (complex< T > &z)
     Returns the real part of the complex number. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICE T const & imag (complex< T > const &z)
     Returns the imaginary part of the complex number. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICE T & imag (complex< T > &z)
     Returns the imaginary part of the complex number. More...
     
    template<typename T >
    std::ostream & operator<< (std::ostream &out, complex< T > const &z)
     
    template<typename T >
    CUTLASS_HOST_DEVICE bool operator== (complex< T > const &lhs, complex< T > const &rhs)
     Equality operator. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICE bool operator!= (complex< T > const &lhs, complex< T > const &rhs)
     Inequality operator. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICE complex< T > operator+ (complex< T > const &lhs, complex< T > const &rhs)
     Addition. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICE complex< T > operator- (complex< T > const &lhs, complex< T > const &rhs)
     Subtraction. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICE complex< T > operator* (complex< T > const &lhs, complex< T > const &rhs)
     Multiplication. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICE complex< T > operator* (complex< T > const &lhs, T const &s)
     Scalar Multiplication. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICE complex< T > operator* (T const &s, complex< T > const &rhs)
     Scalar Multiplication. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICE complex< T > operator/ (complex< T > const &lhs, complex< T > const &rhs)
     Division. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICE complex< T > operator/ (complex< T > const &lhs, T const &s)
     Scalar Division. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICE complex< T > operator/ (T const &s, complex< T > const &rhs)
     Scalar divided by complex. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICE complex< T > & operator+= (complex< T > &lhs, complex< T > const &rhs)
     Addition. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICE complex< T > & operator-= (complex< T > &lhs, complex< T > const &rhs)
     Subtraction. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICE complex< T > & operator*= (complex< T > &lhs, complex< T > const &rhs)
     Multiplication. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICE complex< T > & operator*= (complex< T > &lhs, T s)
     Scalar multiplication. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICE complex< T > & operator/= (complex< T > &lhs, complex< T > const &rhs)
     Division. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICEabs (complex< T > const &z)
     Returns the magnitude of the complex number. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICEarg (complex< T > const &z)
     Returns the magnitude of the complex number. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICEnorm (complex< T > const &z)
     Returns the squared magnitude. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICE complex< T > conj (complex< T > const &z)
     Returns the complex conjugate. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICE complex< T > proj (complex< T > const &z)
     Projects the complex number z onto the Riemann sphere. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICE complex< T > polar (T const &r, T const &theta=T())
     Returns a complex number with magnitude r and phase theta. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICE complex< T > exp (complex< T > const &z)
     Computes the complex exponential of z. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICE complex< T > log (complex< T > const &z)
     Computes the complex exponential of z. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICE complex< T > log10 (complex< T > const &z)
     Computes the complex exponential of z. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICE complex< T > sqrt (complex< T > const &z)
     Computes the square root of complex number z. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICE complex< T > cos (complex< T > const &z)
     Computes the cosine of complex z. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICE complex< T > sin (complex< T > const &z)
     Computes the sin of complex z. More...
     
    template<typename T >
    CUTLASS_HOST_DEVICE constexpr const T & min (const T &a, const T &b)
     std::min More...
    struct cutlass::platform::__align__ (64  )
    struct cutlass::platform::__align__ (128  )
    struct cutlass::platform::__align__ (256 16  )
    +
    + + +

    ◆ __align__() [7/13]

    + +
    +
    +
    +template<>
    + + + + + + + + +
    struct cutlass::platform::__align__ (32 )
    +
    + +
    +
    + +

    ◆ __align__() [8/13]

    + +
    +
    +
    +template<>
    + + + + + + + + +
    struct cutlass::platform::__align__ (64 )
    +
    + +
    +
    + +

    ◆ __align__() [9/13]

    + +
    +
    +
    +template<>
    + + + + + + + + +
    struct cutlass::platform::__align__ (128 )
    +
    + +
    +
    + +

    ◆ __align__() [10/13]

    + +
    +
    +
    +template<>
    + + + + + + + + +
    struct cutlass::platform::__align__ (256 )
    +
    +
    -

    ◆ __align__() [7/13]

    +

    ◆ __align__() [11/13]

    @@ -500,7 +733,7 @@ template<>
    -

    ◆ __align__() [8/13]

    +

    ◆ __align__() [12/13]

    @@ -520,7 +753,7 @@ template<>
    -

    ◆ __align__() [9/13]

    +

    ◆ __align__() [13/13]

    @@ -539,19 +772,19 @@ template<>
    - -

    ◆ __align__() [10/13]

    + +

    ◆ abs()

    -template<>
    +template<typename T >
    - + - - + +
    struct cutlass::platform::__align__ CUTLASS_HOST_DEVICE T cutlass::platform::abs (32 )complex< T > const & z)
    @@ -559,19 +792,19 @@ template<>
    - -

    ◆ __align__() [11/13]

    + +

    ◆ arg()

    -template<>
    +template<typename T >
    - + - - + +
    struct cutlass::platform::__align__ CUTLASS_HOST_DEVICE T cutlass::platform::arg ()complex< T > const & z)
    @@ -579,19 +812,19 @@ template<>
    - -

    ◆ __align__() [12/13]

    + +

    ◆ conj()

    -template<>
    +template<typename T >
    - + - - + +
    struct cutlass::platform::__align__ CUTLASS_HOST_DEVICE complex<T> cutlass::platform::conj ()complex< T > const & z)
    @@ -599,19 +832,191 @@ template<>
    - -

    ◆ __align__() [13/13]

    + +

    ◆ cos()

    -template<>
    +template<typename T >
    - + - - + + + + +
    struct cutlass::platform::__align__ CUTLASS_HOST_DEVICE complex<T> cutlass::platform::cos (16 )complex< T > const & z)
    +
    + +
    + + +

    ◆ exp()

    + +
    +
    +
    +template<typename T >
    + + + + + + + + +
    CUTLASS_HOST_DEVICE complex<T> cutlass::platform::exp (complex< T > const & z)
    +
    + +
    +
    + +

    ◆ imag() [1/6]

    + +
    +
    + + + + + + + + +
    CUTLASS_HOST_DEVICE float const& cutlass::platform::imag (cuFloatComplex const & z)
    +
    + +
    +
    + +

    ◆ imag() [2/6]

    + +
    +
    + + + + + + + + +
    CUTLASS_HOST_DEVICE float& cutlass::platform::imag (cuFloatComplex & z)
    +
    + +
    +
    + +

    ◆ imag() [3/6]

    + +
    +
    + + + + + + + + +
    CUTLASS_HOST_DEVICE double const& cutlass::platform::imag (cuDoubleComplex const & z)
    +
    + +
    +
    + +

    ◆ imag() [4/6]

    + +
    +
    + + + + + + + + +
    CUTLASS_HOST_DEVICE double& cutlass::platform::imag (cuDoubleComplex & z)
    +
    + +
    +
    + +

    ◆ imag() [5/6]

    + +
    +
    +
    +template<typename T >
    + + + + + + + + +
    CUTLASS_HOST_DEVICE T const& cutlass::platform::imag (complex< T > const & z)
    +
    + +
    +
    + +

    ◆ imag() [6/6]

    + +
    +
    +
    +template<typename T >
    + + + + + + + + +
    CUTLASS_HOST_DEVICE T& cutlass::platform::imag (complex< T > & z)
    +
    + +
    +
    + +

    ◆ log()

    + +
    +
    +
    +template<typename T >
    + + + + + + + + +
    CUTLASS_HOST_DEVICE complex<T> cutlass::platform::log (complex< T > const & z)
    +
    + +
    +
    + +

    ◆ log10()

    + +
    +
    +
    +template<typename T >
    + + + + + +
    CUTLASS_HOST_DEVICE complex<T> cutlass::platform::log10 (complex< T > const & z)
    @@ -707,10 +1112,30 @@ template<typename T >
    +
    + + +

    ◆ norm()

    + +
    +
    +
    +template<typename T >
    + + + + + + + + +
    CUTLASS_HOST_DEVICE T cutlass::platform::norm (complex< T > const & z)
    +
    +
    -

    ◆ operator!=()

    +

    ◆ operator!=() [1/2]

    @@ -737,6 +1162,426 @@ template<class T1 , class T2 >
    +
    + + +

    ◆ operator!=() [2/2]

    + +
    +
    +
    +template<typename T >
    + + + + + + + + + + + + + + + + + + +
    CUTLASS_HOST_DEVICE bool cutlass::platform::operator!= (complex< T > const & lhs,
    complex< T > const & rhs 
    )
    +
    + +
    +
    + +

    ◆ operator*() [1/3]

    + +
    +
    +
    +template<typename T >
    + + + + + + + + + + + + + + + + + + +
    CUTLASS_HOST_DEVICE complex<T> cutlass::platform::operator* (complex< T > const & lhs,
    complex< T > const & rhs 
    )
    +
    + +
    +
    + +

    ◆ operator*() [2/3]

    + +
    +
    +
    +template<typename T >
    + + + + + + + + + + + + + + + + + + +
    CUTLASS_HOST_DEVICE complex<T> cutlass::platform::operator* (complex< T > const & lhs,
    T const & s 
    )
    +
    + +
    +
    + +

    ◆ operator*() [3/3]

    + +
    +
    +
    +template<typename T >
    + + + + + + + + + + + + + + + + + + +
    CUTLASS_HOST_DEVICE complex<T> cutlass::platform::operator* (T const & s,
    complex< T > const & rhs 
    )
    +
    + +
    +
    + +

    ◆ operator*=() [1/2]

    + +
    +
    +
    +template<typename T >
    + + + + + + + + + + + + + + + + + + +
    CUTLASS_HOST_DEVICE complex<T>& cutlass::platform::operator*= (complex< T > & lhs,
    complex< T > const & rhs 
    )
    +
    + +
    +
    + +

    ◆ operator*=() [2/2]

    + +
    +
    +
    +template<typename T >
    + + + + + + + + + + + + + + + + + + +
    CUTLASS_HOST_DEVICE complex<T>& cutlass::platform::operator*= (complex< T > & lhs,
    s 
    )
    +
    + +
    +
    + +

    ◆ operator+()

    + +
    +
    +
    +template<typename T >
    + + + + + + + + + + + + + + + + + + +
    CUTLASS_HOST_DEVICE complex<T> cutlass::platform::operator+ (complex< T > const & lhs,
    complex< T > const & rhs 
    )
    +
    + +
    +
    + +

    ◆ operator+=()

    + +
    +
    +
    +template<typename T >
    + + + + + + + + + + + + + + + + + + +
    CUTLASS_HOST_DEVICE complex<T>& cutlass::platform::operator+= (complex< T > & lhs,
    complex< T > const & rhs 
    )
    +
    + +
    +
    + +

    ◆ operator-()

    + +
    +
    +
    +template<typename T >
    + + + + + + + + + + + + + + + + + + +
    CUTLASS_HOST_DEVICE complex<T> cutlass::platform::operator- (complex< T > const & lhs,
    complex< T > const & rhs 
    )
    +
    + +
    +
    + +

    ◆ operator-=()

    + +
    +
    +
    +template<typename T >
    + + + + + + + + + + + + + + + + + + +
    CUTLASS_HOST_DEVICE complex<T>& cutlass::platform::operator-= (complex< T > & lhs,
    complex< T > const & rhs 
    )
    +
    + +
    +
    + +

    ◆ operator/() [1/3]

    + +
    +
    +
    +template<typename T >
    + + + + + + + + + + + + + + + + + + +
    CUTLASS_HOST_DEVICE complex<T> cutlass::platform::operator/ (complex< T > const & lhs,
    complex< T > const & rhs 
    )
    +
    + +
    +
    + +

    ◆ operator/() [2/3]

    + +
    +
    +
    +template<typename T >
    + + + + + + + + + + + + + + + + + + +
    CUTLASS_HOST_DEVICE complex<T> cutlass::platform::operator/ (complex< T > const & lhs,
    T const & s 
    )
    +
    + +
    +
    + +

    ◆ operator/() [3/3]

    + +
    +
    +
    +template<typename T >
    + + + + + + + + + + + + + + + + + + +
    CUTLASS_HOST_DEVICE complex<T> cutlass::platform::operator/ (T const & s,
    complex< T > const & rhs 
    )
    +
    + +
    +
    + +

    ◆ operator/=()

    + +
    +
    +
    +template<typename T >
    + + + + + + + + + + + + + + + + + + +
    CUTLASS_HOST_DEVICE complex<T>& cutlass::platform::operator/= (complex< T > & lhs,
    complex< T > const & rhs 
    )
    +
    +
    @@ -767,6 +1612,36 @@ template<class T1 , class T2 >
    +
    + + +

    ◆ operator<<()

    + +
    +
    +
    +template<typename T >
    + + + + + + + + + + + + + + + + + + +
    std::ostream& cutlass::platform::operator<< (std::ostream & out,
    complex< T > const & z 
    )
    +
    +
    @@ -797,10 +1672,40 @@ template<class T1 , class T2 >
    +
    + + +

    ◆ operator==() [1/2]

    + +
    +
    +
    +template<typename T >
    + + + + + + + + + + + + + + + + + + +
    CUTLASS_HOST_DEVICE bool cutlass::platform::operator== (complex< T > const & lhs,
    complex< T > const & rhs 
    )
    +
    +
    -

    ◆ operator==()

    +

    ◆ operator==() [2/2]

    @@ -887,6 +1792,208 @@ template<class T1 , class T2 >
    +
    + + +

    ◆ polar()

    + +
    +
    +
    +template<typename T >
    + + + + + + + + + + + + + + + + + + +
    CUTLASS_HOST_DEVICE complex<T> cutlass::platform::polar (T const & r,
    T const & theta = T() 
    )
    +
    + +
    +
    + +

    ◆ proj()

    + +
    +
    +
    +template<typename T >
    + + + + + + + + +
    CUTLASS_HOST_DEVICE complex<T> cutlass::platform::proj (complex< T > const & z)
    +
    + +
    +
    + +

    ◆ real() [1/6]

    + +
    +
    + + + + + + + + +
    CUTLASS_HOST_DEVICE float const& cutlass::platform::real (cuFloatComplex const & z)
    +
    + +
    +
    + +

    ◆ real() [2/6]

    + +
    +
    + + + + + + + + +
    CUTLASS_HOST_DEVICE float& cutlass::platform::real (cuFloatComplex & z)
    +
    + +
    +
    + +

    ◆ real() [3/6]

    + +
    +
    + + + + + + + + +
    CUTLASS_HOST_DEVICE double const& cutlass::platform::real (cuDoubleComplex const & z)
    +
    + +
    +
    + +

    ◆ real() [4/6]

    + +
    +
    + + + + + + + + +
    CUTLASS_HOST_DEVICE double& cutlass::platform::real (cuDoubleComplex & z)
    +
    + +
    +
    + +

    ◆ real() [5/6]

    + +
    +
    +
    +template<typename T >
    + + + + + + + + +
    CUTLASS_HOST_DEVICE T const& cutlass::platform::real (complex< T > const & z)
    +
    + +
    +
    + +

    ◆ real() [6/6]

    + +
    +
    +
    +template<typename T >
    + + + + + + + + +
    CUTLASS_HOST_DEVICE T& cutlass::platform::real (complex< T > & z)
    +
    + +
    +
    + +

    ◆ sin()

    + +
    +
    +
    +template<typename T >
    + + + + + + + + +
    CUTLASS_HOST_DEVICE complex<T> cutlass::platform::sin (complex< T > const & z)
    +
    + +
    +
    + +

    ◆ sqrt()

    + +
    +
    +
    +template<typename T >
    + + + + + + + + +
    CUTLASS_HOST_DEVICE complex<T> cutlass::platform::sqrt (complex< T > const & z)
    +
    +
    @@ -930,7 +2037,7 @@ template<typename T , typename Deleter > diff --git a/docs/namespacemembers.html b/docs/namespacemembers.html index 9566721d1..a522eab71 100644 --- a/docs/namespacemembers.html +++ b/docs/namespacemembers.html @@ -73,22 +73,54 @@ $(function() {

    - _ -

    + + +

    - a -

    - c -

    +

    - e -

    + +

    - f -

    @@ -96,36 +128,30 @@ $(function() {
  • gcd() : cutlass
  • -
  • gemm_kernel() -: cutlass::gemm +
  • gemm_kernel_nolb() +: cutlass::gemm
  • -
  • get_Coord_dhw() -: cutlass +
  • getLinearIdx() +: cutlass::gemm
  • -
  • get_Coord_hw() -: cutlass -
  • -
  • get_Coord_hwc() -: cutlass +
  • getLinearIdx< swizzleDirection::Boustrophedon >() +: cutlass::gemm
  • - i -

    @@ -134,6 +160,12 @@ $(function() {
  • lcm() : cutlass
  • +
  • log() +: cutlass::platform +
  • +
  • log10() +: cutlass::platform +
  • @@ -141,11 +173,23 @@ $(function() {
  • make_Coord() : cutlass
  • +
  • make_Coord_from_shape() +: cutlass +
  • make_pair() : cutlass::platform
  • make_zero() -: cutlass +: cutlass +
  • +
  • make_ZipConvert() +: cutlass +
  • +
  • make_ZipFragment() +: cutlass +
  • +
  • make_ZipTensorRef() +: cutlass
  • max() : cutlass::platform @@ -156,18 +200,53 @@ $(function() { +

    - n -

    + +

    - o -

    +

    - p -

    + +

    - r -

    +

    - n -

    + +

    - o -

    +

    - p -

    + +

    - r -