Committing CUTLASS for release.

This commit is contained in:
akerr
2017-12-04 21:12:52 -08:00
parent bbb3178126
commit d08ba8ac46
35 changed files with 10786 additions and 29 deletions

29
LICENSE
View File

@@ -1,29 +0,0 @@
BSD 3-Clause License
Copyright (c) 2017, NVIDIA Corporation
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
* Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

154
cutlass/gemm/block_loader.h Normal file
View File

@@ -0,0 +1,154 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* block-wide tile-loading abstractions
*/
#include "../util/util.h"
namespace cutlass {
namespace gemm {
/******************************************************************************
* load_algorithm
******************************************************************************/
/**
* \brief Enumeration of matrix loading algorithms
*/
struct load_algorithm
{
/// \brief Enumerants. See corresponding tag types.
enum kind_t
{
CongruousCopy = 0,
CrosswiseCopy = 1,
};
/**
* \brief Generic tag
*/
template <kind_t Kind>
struct any_tag : nv_std::integral_constant<kind_t, Kind> {};
/**
* \brief Copy from a global matrix that is row-major in relation
* to the local row-major tile
*/
typedef any_tag<CongruousCopy> contiguous_tag_t;
/**
* \brief Copy from a global matrix that is column-major in relation
* to the local row-major tile
*/
typedef any_tag<CrosswiseCopy> crosswise_tag_t;
};
/******************************************************************************
* block_loader
******************************************************************************/
/**
* \brief A three-phase data loading abstraction (prefetch, commit, and
* advance) for iterating over ranges of block-wide matrix tiles.
*
* Each iteration sequence produces a KxL (height-by-width) block-wide tile of
* value_t in shared memory. The layout of the shared
* block-wide tile is a row-major (L-major) tiling of dp_vector_t items, which are
* themselves column-major (K-major) vectors of value_t. Its dimensions are:
* K = BlockDpVectorsK * (sizeof(dp_vector_t) / sizeof(value_t)
* L = BlockDpVectorsL
*
* NB: This generic class is not directly constructible. Architecture- and
* algorithm-specific template specializations will provide the API
* functionality prescribed here.
*
*/
template <
int BlockThreads, ///< Number of threads in each thread block (blockDim.x)
int BlockDpVectorsK, ///< Extent of block-wide tile in dp_vector_t along the K-axis (height)
int BlockDpVectorsL, ///< Extent of block-wide tile in dp_vector_t along the L-axis (width)
typename value_t, ///< Input matrix value type
int LeadingDimAlignBytes, ///< Byte alignment of input matrix leading dimension
bool AllowRaggedTiles, ///< Whether the input matrix's dimensions need not be an even-multiple of the block-wide tile dimensions
typename dp_vector_t, ///< Dot-product vector type along the K-axis
load_algorithm::kind_t LoadAlgorithm> ///< Algorithm for loading a shared tile of KxL matrix data
struct block_loader
{
//-------------------------------------------------------------------------
// Constructor API
//-------------------------------------------------------------------------
/// Constructor
block_loader(
value_t *d_matrix, ///< Pointer to input matrix
int matrix_values_l, ///< Extent of the input matrix in value_t along the L-axis
int matrix_values_stride_k, ///< Distance in value_t within pitched-linear memory between successive coordinates along the K-axis
int matrix_values_stride_l, ///< Distance in value_t within pitched-linear memory between successive coordinates along the L-axis
int2 block_begin_item_coords, ///< Thread block's starting value_t coordinates (l, k) within the input matrix
int block_end_item_k); ///< Thread block's ending coordinate (k) within the input matrix (one-past)
//-------------------------------------------------------------------------
// Loader API
//-------------------------------------------------------------------------
/**
* Request the current block-wide tile
*/
void request();
/**
* Advance the loader to the next block-wide tile in the K-axis
*/
void next();
/**
* Commit the previously-requested block-wide tile to shared memory
*
* NB: To facilitate padding for avoiding shared memory bank conflicts, we
* allow the row stride _BlockDpVectorsL to be arbitrarily bigger than the
* tile width BlockDpVectorsL.
*/
template <int _BlockDpVectorsL>
void commit(
dp_vector_t (&scratch_tile)[BlockDpVectorsK][_BlockDpVectorsL]);
};
} // namespace gemm
} // namespace cutlass
/******************************************************************************
* Tail-include specializations that adhere to the block_loader API
******************************************************************************/
#include "block_loader_crosswise.h"
#include "block_loader_congruous_dp1.h"
#include "block_loader_congruous_idp4.h"

View File

@@ -0,0 +1,398 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* Tile-loading abstraction for thread blocks
*/
#include "../util/util.h"
namespace cutlass {
namespace gemm {
/******************************************************************************
* block_loader (CongruousCopy + dp1 specialization)
******************************************************************************/
/**
* \brief A three-phase data loading abstraction (prefetch, commit, and
* advance) for iterating over ranges of block-wide matrix tiles.
* (CongruousCopy + dp1 specialization)
*
* Each iteration sequence produces a KxL (height-by-width) block-wide tile of
* value_t in shared memory. The layout of the shared block-wide tile is
* a row-major (L-major) tiling of singleton "dp1" dp_vector_t items, where
* dp_vector_t == value_t. Its dimensions are:
* K = BlockDpVectorsK
* L = BlockDpVectorsL
*
* The data is copied from a corresponding tile of global matrix data whose
* layout of value_t is also L-major. This constitutes a CongruousCopy
* between the L-major global tile and the L-major shared tile.
*
* NB: Because they are "dp1" singletons, the K-major orientation of
* dp_vector_t in shared memory is irrelevant, and the L-major global and
* shared tile layouts are perfectly congruous. As a result, we can increase
* the granularity of data transfer via vectorization of loads and stores
* without any intermediate {dis|re}assembly.
*
* NB: Consecutive threads within a block are mapped in L-major
* fashion across a first-set of LDG-vectors of dp_vector_t (value_t) within
* their global tile. Successive sets of LDG-vectors are then strip-mined
* as necessary down the K-axis. These discontiguous LDG-vectors comprise the
* thread's "slice" of the block-wide tile.
*/
template <
int BlockThreads, ///< Number of threads in each thread block (blockDim.x)
int BlockDpVectorsK, ///< Extent of block-wide tile in dp_vector_t along the K-axis (height)
int BlockDpVectorsL, ///< Extent of block-wide tile in dp_vector_t along the L-axis (width)
typename value_t, ///< Input matrix value type
int LeadingDimAlignBytes, ///< Byte alignment of input matrix leading dimension
bool AllowRaggedTiles ///< Whether the input matrix's dimensions need not be an even-multiple of the block-wide tile dimensions
>
struct block_loader<
BlockThreads,
BlockDpVectorsK,
BlockDpVectorsL,
value_t,
LeadingDimAlignBytes,
AllowRaggedTiles,
value_t, ///< Dot-product vector type along the K-axis (dp1 specialization)
load_algorithm::CongruousCopy> ///< Algorithm for loading a shared tile of KxL matrix data (CongruousCopy specialization)
{
//-------------------------------------------------------------------------
// Constants and types
//-------------------------------------------------------------------------
/// Dot-product vector type along the K-axis
typedef value_t dp_vector_t;
enum
{
/// Number of value_t in a dp_vector_t
DpVectorItems = divide_assert<sizeof(dp_vector_t), sizeof(value_t)>::value,
/// Number of dp_vector_t in a block-wide tile
BlockDpVectors = BlockDpVectorsK * BlockDpVectorsL,
/// Number of dp_vector_t in a thread-tile
ThreadDpVectors = divide_assert<BlockDpVectors, BlockThreads>::value,
};
/// Data movement type, coarsened by LeadingDimAlignBytes, capped by the
/// smaller of either ThreadDpVectors or BlockDpVectorsL
typedef io_vector<
dp_vector_t,
__NV_STD_MIN(ThreadDpVectors, BlockDpVectorsL),
LeadingDimAlignBytes>
ldg_vector_t;
enum
{
/// Number of dp_vector_t per ldg_vector_t
LdgVectorDpVectors = ldg_vector_t::VectorItems,
/// Number of value_t per ldg_vector_t
LdgVectorItems = LdgVectorDpVectors * DpVectorItems,
/// Total number of ldg_vector_t within each block-wide tile
BlockLdgVectors = divide_assert<BlockDpVectors, LdgVectorDpVectors>::value,
/// Extent of the block-wide tile in ldg_vector_t along L-axis
BlockLdgVectorsL = divide_assert<BlockDpVectorsL, LdgVectorDpVectors>::value,
/// Extent of the block-wide tile in ldg_vector_t along K-axis
BlockLdgVectorsK = BlockDpVectorsK,
/// Number of ldg_vector_t within each thread-tile
ThreadLdgVectors = divide_assert<BlockLdgVectors, BlockThreads>::value,
/// Extent of the thread tile in ldg_vector_t along L-axis
ThreadLdgVectorsL = __NV_STD_MAX(1, (BlockLdgVectorsL / BlockThreads)),
/// Extent of the thread tile in ldg_vector_t along K-axis
ThreadLdgVectorsK = divide_assert<ThreadLdgVectors, ThreadLdgVectorsL>::value,
/// Number of ldg_vector_t within each stripmine-tile
StripmineLdgVectors = BlockThreads,
/// Extent of the stripmine tile in ldg_vector_t along L-axis
StripmineLdgVectorsL = __NV_STD_MIN(BlockLdgVectorsL, StripmineLdgVectors),
/// Extent of the stripmine tile in ldg_vector_t along K-axis
StripmineLdgVectorsK = divide_assert<StripmineLdgVectors, StripmineLdgVectorsL>::value,
/// Alignment in dp_vector_t along L needed for committing prefetch
AlignmentDpVectorsL = LdgVectorDpVectors,
};
/// Predicate bit vector
typedef uint64_t predicate_mask_t;
//-------------------------------------------------------------------------
// Assert assumptions
//-------------------------------------------------------------------------
static_assert(
(ThreadLdgVectors <= sizeof(predicate_mask_t) * 8),
"Predicate mask type does not contain enough bits for encoding load predicates");
//-------------------------------------------------------------------------
// Members
//-------------------------------------------------------------------------
/// Input pointer to matrix in ldg_vector_t
ldg_vector_t *d_matrix_ldgvecs;
/// Extent of the input matrix in ldg_vector_t along the L-axis
int matrix_ldgvecs_l;
/// Thread block's ending ldg_vector_t coordinate (k) within the input matrix (one-past)
int block_end_ldgvec_k;
/// Predicate bits for guarding ldg_vector_t loads within "whole-k" block-wide tiles
predicate_mask_t guard;
/// Predicate bits for guarding ldg_vector_t loads within the final block-wide "residue" tile
predicate_mask_t residue_guard;
/// Iteration span in "whole-k" block-wide tiles
int wholek_tiles_remaining;
/// Distance in ldg_vector_t within pitched-linear memory between successive coordinates along the K-axis
int matrix_ldgvec_stride_k;
/// Distance in ldg_vector_t within pitched-linear memory between successive coordinates along the L-axis
int matrix_ldgvec_stride_l;
/// ldg_vector_t coordinates (l, k) of thread-tile within the block-wide tile
int2 block_thread_ldgvec_coords;
/// Thread-wide tile of prefetch data
ldg_vector_t thread_tile[ThreadLdgVectorsK][ThreadLdgVectorsL];
//-------------------------------------------------------------------------
// Constructor API
//-------------------------------------------------------------------------
/// Constructor
inline __device__
block_loader(
value_t *d_matrix_items, ///< Input pointer to matrix in value_t
int matrix_items_l, ///< Extent of the input matrix in value_t along the L-axis
int matrix_items_stride_k, ///< Distance in value_t within pitched-linear memory between successive coordinates along the K-axis
int matrix_items_stride_l, ///< Distance in value_t within pitched-linear memory between successive coordinates along the L-axis
int2 matrix_block_item_coords, ///< value_t coordinates (l, k) of first block-wide tile within the input matrix
int block_end_item_k) ///< Thread block's ending coordinate (k) within the input matrix (one-past)
:
block_end_ldgvec_k(block_end_item_k),
guard(0),
residue_guard(0)
{
matrix_ldgvecs_l = matrix_items_l / LdgVectorItems;
matrix_ldgvec_stride_k = matrix_items_stride_k / LdgVectorItems,
matrix_ldgvec_stride_l = matrix_items_stride_l;
// ldg_vector_t coordinates (l, k) of thread-tile within the block-wide tile
block_thread_ldgvec_coords = make_int2(
threadIdx.x % BlockLdgVectorsL, // l-coordinate
threadIdx.x / BlockLdgVectorsL); // k-coordinate
// ldg_vector_t coordinates (l, k) of first block-wide tile within the input matrix
int2 matrix_block_ldgvec_coords = make_int2(
matrix_block_item_coords.x / LdgVectorItems, // l-coordinate
matrix_block_item_coords.y); // k-coordinate
// Iteration span in ldg_vector_t
int span_ldgvec_k = (block_end_item_k - matrix_block_item_coords.y);
// ldg_vector_t coordinates (l, k) of first thread-tile tile within the input matrix
int2 matrix_thread_ldgvec_coords = make_int2(
block_thread_ldgvec_coords.x + matrix_block_ldgvec_coords.x,
block_thread_ldgvec_coords.y + matrix_block_ldgvec_coords.y);
// Iteration range in "whole-k" block-wide tiles
wholek_tiles_remaining = span_ldgvec_k / BlockLdgVectorsK;
// Extent of final residue-tile in ldg_vector_t along K-axis
int residue_ldgvecs_k = span_ldgvec_k % BlockLdgVectorsK;
// Initialize I/O predicates
if (AllowRaggedTiles)
{
// Outer thread-tile ldg_vector_t iteration (K-axis)
#pragma unroll
for (int thread_ldgvec_k = 0; thread_ldgvec_k < ThreadLdgVectorsK; ++thread_ldgvec_k)
{
int block_ldgvec_k = block_thread_ldgvec_coords.y + (thread_ldgvec_k * StripmineLdgVectorsK);
// Whether block_ldgvec_coords.y is valid in the final residue tile
predicate_mask_t valid_k = (block_ldgvec_k < residue_ldgvecs_k);
// Inner thread-tile ldg_vector_t iteration (L-axis)
#pragma unroll
for (int thread_ldgvec_l = 0; thread_ldgvec_l < ThreadLdgVectorsL; ++thread_ldgvec_l)
{
int block_ldgvec_l = block_thread_ldgvec_coords.x + (thread_ldgvec_l * StripmineLdgVectorsL);
// Whether block_ldgvec_coords.x is valid any block-wide tile
predicate_mask_t valid_l = (matrix_block_ldgvec_coords.x + block_ldgvec_l < matrix_ldgvecs_l);
// Linear index of ldg_vector_t load
int ldgvec_idx = thread_ldgvec_l + (thread_ldgvec_k * ThreadLdgVectorsL);
// Set predicate guard bits
guard |= (valid_l << ldgvec_idx);
residue_guard |= ((valid_l & valid_k) << ldgvec_idx);
}
}
// Promote residue-guard to primary-guard if no full tiles remain
if (!wholek_tiles_remaining)
{
guard = residue_guard;
}
}
// Update the input pointer to be matrix_thread_ldgvec_coords
this->d_matrix_ldgvecs =
reinterpret_cast<ldg_vector_t*>(d_matrix_items) +
(matrix_thread_ldgvec_coords.y * matrix_ldgvec_stride_k) +
(matrix_thread_ldgvec_coords.x * matrix_ldgvec_stride_l);
}
//-------------------------------------------------------------------------
// Loader API
//-------------------------------------------------------------------------
/**
* Request the current block-wide tile
*/
inline __device__
void request()
{
// Outer thread-tile ldg_vector_t iteration (K-axis)
#pragma unroll
for (int thread_ldgvec_k = 0; thread_ldgvec_k < ThreadLdgVectorsK; ++thread_ldgvec_k)
{
// Inner thread-tile ldg_vector_t iteration (L-axis)
#pragma unroll
for (int thread_ldgvec_l = 0; thread_ldgvec_l < ThreadLdgVectorsL; ++thread_ldgvec_l)
{
// Linear index of ldg_vector_t load
int ldgvec_idx = (thread_ldgvec_k * ThreadLdgVectorsL) + thread_ldgvec_l;
// Unpack predicate guard
predicate_mask_t valid = ((guard >> ldgvec_idx) & 1);
if (!AllowRaggedTiles || valid)
{
// Perform load
thread_tile[thread_ldgvec_k][thread_ldgvec_l].load(
d_matrix_ldgvecs +
(thread_ldgvec_k * StripmineLdgVectorsK * matrix_ldgvec_stride_k) +
(thread_ldgvec_l * StripmineLdgVectorsL * matrix_ldgvec_stride_l));
}
else
{
// Zero-initialize
#pragma unroll
for (int dpvec = 0; dpvec < LdgVectorDpVectors; ++dpvec)
thread_tile[thread_ldgvec_k][thread_ldgvec_l].buff[dpvec] = 0;
}
}
}
}
/**
* Advance the loader to the next block-wide tile in the K-axis
*/
inline __device__
void next()
{
d_matrix_ldgvecs += (matrix_ldgvec_stride_k * BlockLdgVectorsK);
if (AllowRaggedTiles)
{
--wholek_tiles_remaining;
// Promote residue-guard to primary-guard if no full tiles remain
if (!wholek_tiles_remaining)
{
guard = residue_guard;
}
}
}
/**
* Commit the previously-requested block-wide tile to shared memory
*
* NB: To facilitate padding for avoiding shared memory bank conflicts, we
* allow the row stride SmemDpVectorsL to be arbitrarily bigger than the
* tile width BlockDpVectorsL.
*/
template <int SmemDpVectorsL>
inline __device__
void commit(
dp_vector_t (&scratch_tile)[BlockDpVectorsK][SmemDpVectorsL])
{
static_assert(SmemDpVectorsL >= BlockDpVectorsL, "Row stride must be >= tile width.");
// Outer thread-tile ldg_vector_t iteration (K-axis)
#pragma unroll
for (int thread_ldgvec_k = 0; thread_ldgvec_k < ThreadLdgVectorsK; ++thread_ldgvec_k)
{
int block_ldgvec_k = block_thread_ldgvec_coords.y + (thread_ldgvec_k * StripmineLdgVectorsK);
// Inner thread-tile ldg_vector_t iteration (L-axis)
#pragma unroll
for (int thread_ldgvec_l = 0; thread_ldgvec_l < ThreadLdgVectorsL; ++thread_ldgvec_l)
{
int block_ldgvec_l = block_thread_ldgvec_coords.x + (thread_ldgvec_l * StripmineLdgVectorsL);
thread_tile[thread_ldgvec_k][thread_ldgvec_l].store(
&scratch_tile[block_ldgvec_k][block_ldgvec_l * LdgVectorDpVectors]);
}
}
}
};
} // namespace gemm
} // namespace cutlass

View File

@@ -0,0 +1,536 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* Tile-loading abstraction for thread blocks
*/
#include "../util/util.h"
namespace cutlass {
namespace gemm {
/******************************************************************************
* block_loader (CongruousCopy + idp4 specialization)
******************************************************************************/
/**
* \brief A three-phase data loading abstraction (prefetch, commit, and
* advance) for iterating over ranges of block-wide matrix tiles.
* (CongruousCopy + idp4 specialization)
*
* Each iteration sequence produces a KxL (height-by-width) block-wide tile of
* value_t in shared memory. The layout of the shared block-wide tile is
* a row-major (L-major) tiling of int32_t dp_vector_t, which are themselves
* column-major (K-major) vectors of int8_t value_t. Its dimensions are:
* K = BlockDpVectorsK * (sizeof(dp_vector_t) / sizeof(value_t)
* L = BlockDpVectorsL
*
* The data is copied from a corresponding tile of global matrix data whose
* layout of value_t is also L-major. This constitutes a CongruousCopy between
* the L-major global tile and the L-major shared tile.
*
* NB: The K-major value_t in shared dp_vector_t are imperfectly congruous
* with the L-major value_t in global memory. As a result, the granularity
* of data transfer is a "dp-square" of (DpVectorItems * DpVectorItems) values
* that must be transposed from L-oriented dp_vector_t to K-oriented
* dp_vector_t prior to commitment.
*
* NB: Consecutive threads within a block are mapped in L-major
* fashion across a first-set of squares within their global tile. Successive
* sets of squares are then strip-mined as necessary down the K-axis. These
* discontiguous squares comprise the thread's "slice" of the block-wide tile.
*/
template <
int BlockThreads, ///< Number of threads in each thread block (blockDim.x)
int _BlockDpVectorsK, ///< Extent of block-wide tile in dp_vector_t along the K-axis (height)
int _BlockDpVectorsL, ///< Extent of block-wide tile in dp_vector_t along the L-axis (width)
int LeadingDimAlignBytes, ///< Byte alignment of input matrix leading dimension
bool AllowRaggedTiles ///< Whether the input matrix's dimensions need not be an even-multiple of the block-wide tile dimensions
>
struct block_loader<
BlockThreads,
_BlockDpVectorsK,
_BlockDpVectorsL,
int8_t, ///< Input matrix value type (idp4 specialization)
LeadingDimAlignBytes,
AllowRaggedTiles,
int32_t, ///< Dot-product vector type along the K-axis (idp4 specialization)
load_algorithm::CongruousCopy> ///< Algorithm for loading a shared tile of KxL matrix data (CrosswiseCopy specialization)
{
//-------------------------------------------------------------------------
// Constants and types
//-------------------------------------------------------------------------
/// Input matrix value type
typedef int8_t value_t;
/// Dot-product vector type along the K-axis
typedef int32_t dp_vector_t;
enum
{
/// Number of value_t in a dp_vector_t
DpVectorItems = divide_assert<sizeof(dp_vector_t), sizeof(value_t)>::value,
/// Number of dp_vector_t in a block-wide tile
BlockDpVectors = _BlockDpVectorsK * _BlockDpVectorsL,
/// Number of dp_vector_t in a thread-tile
ThreadDpVectors = divide_assert<BlockDpVectors, BlockThreads>::value,
/// Number of dp_vector_t in a dp-square
SquareDpVectors = DpVectorItems,
/// Number of dp-square tiles in a thread-tile
ThreadSquares = divide_assert<ThreadDpVectors, SquareDpVectors>::value,
/// Extent of block-wide tile in transposed dp_vector_t along the K-axis (height)
BlockTransDpVectorsK = _BlockDpVectorsK * DpVectorItems,
/// Extent of block-wide tile in transposed dp_vector_t along the L-axis (height)
BlockTransDpVectorsL = divide_assert<_BlockDpVectorsL, DpVectorItems>::value,
};
/// Load-from-global data movement type, coarsened by LeadingDimAlignBytes, capped by the
/// smaller of either ThreadSquares or BlockTransDpVectorsL
typedef io_vector<
dp_vector_t,
__NV_STD_MIN(ThreadSquares, BlockTransDpVectorsL),
LeadingDimAlignBytes>
ldg_vector_t;
/// Store-to-shared data movement type equivalent to a dp-square
typedef io_vector<
dp_vector_t,
SquareDpVectors>
sts_vector_t;
enum
{
/// Number of dp_vector_t per ldg_vector_t
LdgVectorDpVectors = ldg_vector_t::VectorItems,
/// Number of value_t per ldg_vector_t
LdgVectorItems = LdgVectorDpVectors * DpVectorItems,
/// Total number of ldg_vector_t within each block-wide tile
BlockLdgVectors = divide_assert<BlockDpVectors, LdgVectorDpVectors>::value,
/// Extent of the block-wide tile in ldg_vector_t along L-axis
BlockLdgVectorsL = divide_assert<BlockTransDpVectorsL, LdgVectorDpVectors>::value,
/// Extent of the block-wide tile in ldg_vector_t along K-axis
BlockLdgVectorsK = BlockTransDpVectorsK,
/// Number of ldg_vector_t within each thread-tile
ThreadLdgVectors = divide_assert<BlockLdgVectors, BlockThreads>::value,
/// Extent of the thread tile in ldg_vector_t along L-axis
ThreadLdgVectorsL = __NV_STD_MAX(1, (BlockLdgVectorsL / BlockThreads)),
/// Extent of the thread tile in ldg_vector_t along K-axis
ThreadLdgVectorsK = divide_assert<ThreadLdgVectors, ThreadLdgVectorsL>::value,
/// Extent of the thread tile in dp-square tiles along K-axis
ThreadSquaresK = divide_assert<ThreadLdgVectorsK, SquareDpVectors>::value,
/// Number of ldg_vector_t within each stripmine-tile
StripmineLdgVectors = BlockThreads * SquareDpVectors,
/// Extent of the stripmine tile in ldg_vector_t along L-axis
StripmineLdgVectorsL = __NV_STD_MIN(BlockLdgVectorsL, BlockThreads),
/// Extent of the stripmine tile in ldg_vector_t along K-axis
StripmineLdgVectorsK = divide_assert<StripmineLdgVectors, StripmineLdgVectorsL>::value,
/// Extent of the stripmine tile in dp-square tiles along K-axis
StripmineSquaresK = divide_assert<StripmineLdgVectorsK, SquareDpVectors>::value,
/// Alignment in dp_vector_t along L needed for committing prefetch
AlignmentDpVectorsL = LdgVectorDpVectors,
};
/// Predicate mask type
typedef uint32_t predicate_mask_t;
//-------------------------------------------------------------------------
// Assert assumptions
//-------------------------------------------------------------------------
static_assert((LeadingDimAlignBytes >= 4) && (LeadingDimAlignBytes % 4 == 0),
"Alignment for matrix operands to IGEMM must be a multiple of 4 bytes.");
static_assert(
(ThreadLdgVectors <= sizeof(predicate_mask_t) * 8),
"Predicate mask type does not contain enough bits for encoding load predicates");
//-------------------------------------------------------------------------
// Members
//-------------------------------------------------------------------------
/// Input pointer to matrix in ldg_vector_t
ldg_vector_t *d_matrix_ldgvecs;
/// Extent of the input matrix in ldg_vector_t along the L-axis
int matrix_ldgvecs_l;
/// Thread block's ending ldg_vector_t coordinate (k) within the input matrix (one-past)
int block_end_ldgvec_k;
/// Predicate bits for guarding ldg_vector_t loads within "whole-k" block-wide tiles
predicate_mask_t guard;
/// Predicate bits for guarding ldg_vector_t loads within the final block-wide "residue" tile
predicate_mask_t residue_guard;
/// Iteration span in "whole-k" block-wide tiles
int wholek_tiles_remaining;
/// Distance in ldg_vector_t within pitched-linear memory between successive coordinates along the K-axis
int matrix_ldgvec_stride_k;
/// Distance in ldg_vector_t within pitched-linear memory between successive coordinates along the L-axis
int matrix_ldgvec_stride_l;
/// ldg_vector_t coordinates (l, k) of thread-tile within the block-wide tile
int2 block_thread_ldgvec_coords;
/// Thread-wide tile of prefetch data
ldg_vector_t thread_tile[ThreadSquaresK][SquareDpVectors][ThreadLdgVectorsL];
//-------------------------------------------------------------------------
// Utility methods
//-------------------------------------------------------------------------
/**
* \brief Byte-permute. Pick four arbitrary bytes from two 32-bit registers, and reassemble them into a 32-bit destination register. For SM2.0 or later.
*
* \par
* The bytes in the two source registers \p a and \p b are numbered from 0 to 7:
* {\p b, \p a} = {{b7, b6, b5, b4}, {b3, b2, b1, b0}}. For each of the four bytes
* {b3, b2, b1, b0} selected in the return value, a 4-bit selector is defined within
* the four lower "nibbles" of \p index: {\p index } = {n7, n6, n5, n4, n3, n2, n1, n0}
*
* \par Snippet
* The code snippet below illustrates byte-permute.
* \par
* \code
* #include <cub/cub.cuh>
*
* __global__ void ExampleKernel(...)
* {
* int a = 0x03020100;
* int b = 0x07060504;
* int index = 0x00007531;
*
* int selected = prmt(a, b, index); // 0x07050301
*
* \endcode
*
*/
inline __device__
int32_t prmt(int32_t a, int32_t b, unsigned int index)
{
int ret;
asm volatile("prmt.b32 %0, %1, %2, %3;" : "=r"(ret) : "r"(a), "r"(b), "r"(index));
return ret;
}
/**
* Convert a "dp-square" from L-major to K-major
*/
inline __device__
void transpose_dp_square(dp_vector_t (&dp_square)[SquareDpVectors])
{
// Transpose dp_vector_t squares
int32_t y = prmt(dp_square[0], dp_square[1], 0x00007362);
int32_t w = prmt(dp_square[2], dp_square[3], 0x00007362);
int32_t x = prmt(dp_square[0], dp_square[1], 0x00005140);
int32_t z = prmt(dp_square[2], dp_square[3], 0x00005140);
dp_square[0] = prmt(x, z, 0x00005410);
dp_square[1] = prmt(x, z, 0x00007632);
dp_square[2] = prmt(y, w, 0x00005410);
dp_square[3] = prmt(y, w, 0x00007632);
}
//-------------------------------------------------------------------------
// Constructor API
//-------------------------------------------------------------------------
/// Constructor
inline __device__
block_loader(
value_t *d_matrix_items, ///< Input pointer to matrix in value_t
int matrix_items_l, ///< Extent of the input matrix in value_t along the L-axis
int matrix_items_stride_k, ///< Distance in value_t within pitched-linear memory between successive coordinates along the K-axis
int matrix_items_stride_l, ///< Distance in value_t within pitched-linear memory between successive coordinates along the L-axis
int2 matrix_block_item_coords, ///< value_t coordinates (l, k) of first block-wide tile within the input matrix
int block_end_item_k) ///< Thread block's ending coordinate (k) within the input matrix (one-past)
:
block_end_ldgvec_k(block_end_item_k),
guard(0),
residue_guard(0)
{
matrix_ldgvecs_l = matrix_items_l / LdgVectorItems;
matrix_ldgvec_stride_k = matrix_items_stride_k / LdgVectorItems,
matrix_ldgvec_stride_l = matrix_items_stride_l;
// ldg_vector_t coordinates (l, k) of thread-tile within the block-wide tile
block_thread_ldgvec_coords = make_int2(
threadIdx.x % BlockLdgVectorsL, // l-coordinate
(threadIdx.x / BlockLdgVectorsL) * SquareDpVectors); // k-coordinate
// ldg_vector_t coordinates (l, k) of first block-wide tile within the input matrix
int2 matrix_block_ldgvec_coords = make_int2(
matrix_block_item_coords.x / LdgVectorItems, // l-coordinate
matrix_block_item_coords.y); // k-coordinate
// Iteration span in ldg_vector_t
int span_ldgvec_k = (block_end_item_k - matrix_block_item_coords.y);
// ldg_vector_t coordinates (l, k) of first thread-tile tile within the input matrix
int2 matrix_thread_ldgvec_coords = make_int2(
block_thread_ldgvec_coords.x + matrix_block_ldgvec_coords.x,
block_thread_ldgvec_coords.y + matrix_block_ldgvec_coords.y);
// Iteration range in "whole-k" block-wide tiles
wholek_tiles_remaining = span_ldgvec_k / BlockLdgVectorsK;
// Extent of final residue-tile in ldg_vector_t along K-axis
int residue_ldgvecs_k = span_ldgvec_k % BlockLdgVectorsK;
// Initialize I/O predicates
if (AllowRaggedTiles)
{
// Iterate through rows of squares in thread tile
#pragma unroll
for (int thread_square_k = 0; thread_square_k < ThreadSquaresK; ++thread_square_k)
{
// Iterate through rows of dp_vector_t in each square
#pragma unroll
for (int square_dpvec = 0; square_dpvec < SquareDpVectors; ++square_dpvec)
{
// ldg_vector_t K-coordinate in block-wide tile (K-axis strip-mining of ldg_vector_t within block-tile)
int block_ldgvec_k =
block_thread_ldgvec_coords.y +
(thread_square_k * StripmineLdgVectorsK) +
square_dpvec;
// Whether block_ldgvec_coords.y is valid in the final residue tile
predicate_mask_t valid_k = (block_ldgvec_k < residue_ldgvecs_k);
// L-axis strip-mining of block-tile
#pragma unroll
for (int thread_ldgvec_l = 0; thread_ldgvec_l < ThreadLdgVectorsL; ++thread_ldgvec_l)
{
// ldg_vector_t L-coordinate in block-wide tile (L-axis strip-mining of ldg_vector_t within block-tile)
int block_ldgvec_l = block_thread_ldgvec_coords.x + (thread_ldgvec_l * StripmineLdgVectorsL);
// Whether block_ldgvec_coords.x is valid any block-wide tile
predicate_mask_t valid_l = (matrix_block_ldgvec_coords.x + block_ldgvec_l < matrix_ldgvecs_l);
// Linear index of ldg_vector_t load
int ldgvec_idx =
(thread_square_k * SquareDpVectors * ThreadLdgVectorsL) +
(square_dpvec * ThreadLdgVectorsL) +
thread_ldgvec_l;
// Set predicate guard bits
guard |= (valid_l << ldgvec_idx);
residue_guard |= ((valid_l & valid_k) << ldgvec_idx);
}
}
}
// Promote residue-guard to primary-guard if no full tiles remain
if (!wholek_tiles_remaining)
{
guard = residue_guard;
}
}
// Update the input pointer to be matrix_thread_ldgvec_coords
this->d_matrix_ldgvecs =
reinterpret_cast<ldg_vector_t*>(d_matrix_items) +
(matrix_thread_ldgvec_coords.y * matrix_ldgvec_stride_k) +
(matrix_thread_ldgvec_coords.x * matrix_ldgvec_stride_l);
}
//-------------------------------------------------------------------------
// Loader API
//-------------------------------------------------------------------------
/**
* Request the current block-wide tile
*/
inline __device__
void request()
{
// Each thread iterates through the ldg_vector_t in its thread tile
// Iterate through rows of squares in thread tile
#pragma unroll
for (int thread_square_k = 0; thread_square_k < ThreadSquaresK; ++thread_square_k)
{
// Iterate through rows of dp_vector_t in each square
#pragma unroll
for (int square_dpvec = 0; square_dpvec < SquareDpVectors; ++square_dpvec)
{
// Iterate through ldg_vector_t in each row
#pragma unroll
for (int thread_ldgvec_l = 0; thread_ldgvec_l < ThreadLdgVectorsL; ++thread_ldgvec_l)
{
// Linear index of ldg_vector_t load
int ldgvec_idx =
(thread_square_k * SquareDpVectors * ThreadLdgVectorsL) +
(square_dpvec * ThreadLdgVectorsL) +
thread_ldgvec_l;
// Unpack predicate guard
predicate_mask_t valid = ((guard >> ldgvec_idx) & 1);
if (!AllowRaggedTiles || valid)
{
// Perform load
thread_tile[thread_square_k][square_dpvec][thread_ldgvec_l].load(
d_matrix_ldgvecs +
(((thread_square_k * StripmineLdgVectorsK) + square_dpvec) * matrix_ldgvec_stride_k) +
(thread_ldgvec_l * StripmineLdgVectorsL * matrix_ldgvec_stride_l));
}
else
{
// Zero-initialize
#pragma unroll
for (int dpvec = 0; dpvec < LdgVectorDpVectors; ++dpvec)
thread_tile[thread_square_k][square_dpvec][thread_ldgvec_l].buff[dpvec] = 0;
}
}
}
}
}
/**
* Advance the loader to the next block-wide tile in the K-axis
*/
inline __device__
void next()
{
d_matrix_ldgvecs += (matrix_ldgvec_stride_k * BlockLdgVectorsK);
if (AllowRaggedTiles)
{
--wholek_tiles_remaining;
// Promote residue-guard to primary-guard if no full tiles remain
if (!wholek_tiles_remaining)
{
guard = residue_guard;
}
}
}
/**
* Commit the previously-requested block-wide tile to shared memory
*
* NB: To facilitate padding for avoiding shared memory bank conflicts, we
* allow the row stride SmemDpVectorsL to be arbitrarily bigger than the
* tile width BlockDpVectorsL.
*/
template <int SmemDpVectorsL>
inline __device__
void commit(
dp_vector_t (&scratch_tile)[_BlockDpVectorsK][SmemDpVectorsL])
{
static_assert(SmemDpVectorsL >= _BlockDpVectorsL, "Row stride must be >= tile width.");
// Square K-coordinate of thread tile in block-wide tile
int block_thread_square_k = block_thread_ldgvec_coords.y / SquareDpVectors;
// Iterate through rows of squares in thread tile
#pragma unroll
for (int thread_square_k = 0; thread_square_k < ThreadSquaresK; ++thread_square_k)
{
// Square K-coordinate in block-wide tile (K-axis strip-mining of squares within block-tile)
int block_square_k = block_thread_square_k + (thread_square_k * StripmineSquaresK);
// Iterate through ldg_vector_t in each row
#pragma unroll
for (int thread_ldgvec_l = 0; thread_ldgvec_l < ThreadLdgVectorsL; ++thread_ldgvec_l)
{
// ldg_vector_t L-coordinate in block-wide tile (L-axis strip-mining of ldg_vector_t within block-tile)
int block_ldgvec_l = block_thread_ldgvec_coords.x + (thread_ldgvec_l * StripmineLdgVectorsL);
// Iterate through squares in each ldg_vector_t
#pragma unroll
for (int ldgvec_dpvec_l = 0; ldgvec_dpvec_l < LdgVectorDpVectors; ++ldgvec_dpvec_l)
{
// Square L-coordinate in block-wide tile (L-axis raking of square-slices within ldg_vector_t)
int block_square_l = (block_ldgvec_l * LdgVectorDpVectors) + ldgvec_dpvec_l;
// Assemble square of L-major dp_vector_t from stack of slices
sts_vector_t square;
// Iterate through rows of dp_vector_t in each square
#pragma unroll
for (int square_dpvec = 0; square_dpvec < SquareDpVectors; ++square_dpvec)
{
square.buff[square_dpvec] = thread_tile[thread_square_k][square_dpvec][thread_ldgvec_l].buff[ldgvec_dpvec_l];
}
// Un-transpose square from L-major to K-major
transpose_dp_square(square.buff);
// Store dp-square
square.store(&scratch_tile[block_square_k][block_square_l * SquareDpVectors]);
}
}
}
}
};
} // namespace gemm
} // namespace cutlass

View File

@@ -0,0 +1,403 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* Tile-loading abstraction for thread blocks
*/
#include "../util/util.h"
namespace cutlass {
namespace gemm {
/******************************************************************************
* block_loader (CrosswiseCopy specialization)
******************************************************************************/
/**
* \brief A three-phase data loading abstraction (prefetch, commit, and
* advance) for iterating over ranges of block-wide matrix tiles.
* (CrosswiseCopy specialization)
*
* Each iteration sequence produces a KxL (height-by-width) block-wide tile of
* value_t in shared memory. The layout of the shared block-wide tile is
* a row-major (L-major) tiling of dp_vector_t items, which are themselves
* column-major (K-major) vectors of value_t. Its dimensions are:
* K = BlockDpVectorsK * (sizeof(dp_vector_t) / sizeof(value_t)
* L = BlockDpVectorsL
*
* The data is copied from a corresponding tile of global matrix data whose
* layout of value_t is K-major. This constitutes a CrosswiseCopy between
* the K-major global tile and the L-major shared tile.
*
* NB: The orientation of dp_vector_t components in shared memory is congruous
* with the global matrix data, so we can use dp_vector_t as the minimum
* granularity of data transfer without any intermediate {dis|re}assembly
* of its value_t components. However, the global and shared memory layouts
* of dp_vector_t items are cross-wise with respect to each other, so any
* further LDG-vectorization of dp_vector_t data requires intermediate
* disassembly into dp_vector_t components to be stored individually into
* the shared tile.
*
* NB: Consecutive threads within a block are mapped in K-major
* fashion down a first set of LDG-vectors of dp_vector_t within their global
* tile. Successive sets of LDG-vectors are then strip-mined as necessary
* across the L-axis. These discontiguous LDG-vectors comprise the thread's
* "slice" of the block-wide tile.
*/
template <
int BlockThreads, ///< Number of threads in each thread block (blockDim.x)
int BlockDpVectorsK, ///< Extent of block-wide tile in dp_vector_t along the K-axis (height)
int BlockDpVectorsL, ///< Extent of block-wide tile in dp_vector_t along the L-axis (width)
typename value_t, ///< Input matrix value type
int LeadingDimAlignBytes, ///< Byte alignment of input matrix leading dimension
bool AllowRaggedTiles, ///< Whether the input matrix's dimensions need not be an even-multiple of the block-wide tile dimensions
typename dp_vector_t> ///< Dot-product vector type along the K-axis
struct block_loader<
BlockThreads,
BlockDpVectorsK,
BlockDpVectorsL,
value_t,
LeadingDimAlignBytes,
AllowRaggedTiles,
dp_vector_t,
load_algorithm::CrosswiseCopy> ///< Algorithm for loading a shared tile of KxL matrix data (CrosswiseCopy specialization)
{
//-------------------------------------------------------------------------
// Constants and types
//-------------------------------------------------------------------------
enum
{
/// Number of value_t in a dp_vector_t
DpVectorItems = divide_assert<sizeof(dp_vector_t), sizeof(value_t)>::value,
/// Number of dp_vector_t in a block-wide tile
BlockDpVectors = BlockDpVectorsK * BlockDpVectorsL,
/// Number of dp_vector_t in a thread-tile
ThreadDpVectors = divide_assert<BlockDpVectors, BlockThreads>::value,
};
/// Data movement type, coarsened by LeadingDimAlignBytes, capped by the
/// smaller of either ThreadDpVectors or BlockDpVectorsK
typedef io_vector<
dp_vector_t,
__NV_STD_MIN(ThreadDpVectors, BlockDpVectorsK),
LeadingDimAlignBytes>
ldg_vector_t;
enum
{
/// Number of dp_vector_t per ldg_vector_t
LdgVectorDpVectors = ldg_vector_t::VectorItems,
/// Number of value_t per ldg_vector_t
LdgVectorItems = LdgVectorDpVectors * DpVectorItems,
/// Total number of ldg_vector_t within each block-wide tile
BlockLdgVectors = divide_assert<BlockDpVectors, LdgVectorDpVectors>::value,
/// Extent of the block-wide tile in ldg_vector_t along K-axis
BlockLdgVectorsK = divide_assert<BlockDpVectorsK, LdgVectorDpVectors>::value,
/// Extent of the block-wide tile in ldg_vector_t along L-axis
BlockLdgVectorsL = BlockDpVectorsL,
/// Number of ldg_vector_t within each thread-tile
ThreadLdgVectors = divide_assert<BlockLdgVectors, BlockThreads>::value,
/// Extent of the thread tile in ldg_vector_t along K-axis
ThreadLdgVectorsK = __NV_STD_MAX(1, (BlockLdgVectorsK / BlockThreads)),
/// Extent of the thread tile in ldg_vector_t along L-axis
ThreadLdgVectorsL = divide_assert<ThreadLdgVectors, ThreadLdgVectorsK>::value,
/// Number of ldg_vector_t within each stripmine-tile
StripmineLdgVectors = BlockThreads,
/// Extent of the stripmine tile in ldg_vector_t along K-axis
StripmineLdgVectorsK = __NV_STD_MIN(BlockLdgVectorsK, StripmineLdgVectors),
/// Extent of the stripmine tile in ldg_vector_t along L-axis
StripmineLdgVectorsL = divide_assert<StripmineLdgVectors, StripmineLdgVectorsK>::value,
/// Alignment in dp_vector_t along L needed for committing prefetch
AlignmentDpVectorsL = 1,
};
/// Predicate bit vector
typedef uint64_t predicate_mask_t;
//-------------------------------------------------------------------------
// Assert assumptions
//-------------------------------------------------------------------------
static_assert(
(ThreadLdgVectors <= sizeof(predicate_mask_t) * 8),
"Predicate mask type does not contain enough bits for encoding load predicates");
//-------------------------------------------------------------------------
// Members
//-------------------------------------------------------------------------
/// Input pointer to matrix in ldg_vector_t
ldg_vector_t *d_matrix_ldgvecs;
/// Extent of the input matrix in ldg_vector_t along the L-axis
int matrix_ldgvecs_l;
/// Thread block's ending ldg_vector_t coordinate (k) within the input matrix (one-past)
int block_end_ldgvec_k;
/// Predicate bits for guarding ldg_vector_t loads within "whole-k" block-wide tiles
predicate_mask_t guard;
/// Predicate bits for guarding ldg_vector_t loads within the final block-wide "residue" tile
predicate_mask_t residue_guard;
/// Iteration span in "whole-k" block-wide tiles
int wholek_tiles_remaining;
/// Distance in ldg_vector_t within pitched-linear memory between successive coordinates along the K-axis
int matrix_ldgvec_stride_k;
/// Distance in ldg_vector_t within pitched-linear memory between successive coordinates along the L-axis
int matrix_ldgvec_stride_l;
/// ldg_vector_t coordinates (l, k) of thread-tile within the block-wide tile
int2 block_thread_ldgvec_coords;
/// Thread-wide tile of prefetch data
ldg_vector_t thread_tile[ThreadLdgVectorsK][ThreadLdgVectorsL];
//-------------------------------------------------------------------------
// Constructor API
//-------------------------------------------------------------------------
/// Constructor
inline __device__
block_loader(
value_t *d_matrix_items, ///< Input pointer to matrix in value_t
int matrix_items_l, ///< Extent of the input matrix in value_t along the L-axis
int matrix_items_stride_k, ///< Distance in value_t within pitched-linear memory between successive coordinates along the K-axis
int matrix_items_stride_l, ///< Distance in value_t within pitched-linear memory between successive coordinates along the L-axis
int2 matrix_block_item_coords, ///< value_t coordinates (l, k) of first block-wide tile within the input matrix
int block_end_item_k) ///< Thread block's ending coordinate (k) within the input matrix (one-past)
:
block_end_ldgvec_k(block_end_item_k),
guard(0),
residue_guard(0)
{
matrix_ldgvecs_l = matrix_items_l;
matrix_ldgvec_stride_k = matrix_items_stride_k;
matrix_ldgvec_stride_l = (matrix_items_stride_l / LdgVectorItems);
// ldg_vector_t coordinates (l, k) of thread-tile within the block-wide tile
block_thread_ldgvec_coords = make_int2(
(threadIdx.x / BlockLdgVectorsK), // l-coordinate
(threadIdx.x % BlockLdgVectorsK)); // k-coordinate
// ldg_vector_t coordinates (l, k) of first block-wide tile within the input matrix
int2 matrix_block_ldgvec_coords = make_int2(
matrix_block_item_coords.x, // l-coordinate
matrix_block_item_coords.y / LdgVectorItems); // k-coordinate
// Iteration span in ldg_vector_t
int span_ldgvec_k = (block_end_item_k - matrix_block_item_coords.y) / LdgVectorItems;
// ldg_vector_t coordinates (l, k) of first thread-tile tile within the input matrix
int2 matrix_thread_ldgvec_coords = make_int2(
block_thread_ldgvec_coords.x + matrix_block_ldgvec_coords.x,
block_thread_ldgvec_coords.y + matrix_block_ldgvec_coords.y);
// Iteration range in "whole-k" block-wide tiles
wholek_tiles_remaining = span_ldgvec_k / BlockLdgVectorsK;
// Extent of final residue-tile in ldg_vector_t along K-axis
int residue_ldgvecs_k = span_ldgvec_k % BlockLdgVectorsK;
// Initialize I/O predicates
if (AllowRaggedTiles)
{
// Outer thread-tile ldg_vector_t iteration (K-axis)
#pragma unroll
for (int thread_ldgvec_k = 0; thread_ldgvec_k < ThreadLdgVectorsK; ++thread_ldgvec_k)
{
int block_ldgvec_k = block_thread_ldgvec_coords.y + (thread_ldgvec_k * StripmineLdgVectorsK);
// Whether block_ldgvec_coords.y is valid in the final residue tile
predicate_mask_t valid_k = (block_ldgvec_k < residue_ldgvecs_k);
// Inner thread-tile ldg_vector_t iteration (L-axis)
#pragma unroll
for (int thread_ldgvec_l = 0; thread_ldgvec_l < ThreadLdgVectorsL; ++thread_ldgvec_l)
{
int block_ldgvec_l = block_thread_ldgvec_coords.x + (thread_ldgvec_l * StripmineLdgVectorsL);
// Whether block_ldgvec_coords.x is valid any block-wide tile
predicate_mask_t valid_l = (matrix_block_ldgvec_coords.x + block_ldgvec_l < matrix_ldgvecs_l);
// Linear index of ldg_vector_t load
int ldgvec_idx = thread_ldgvec_l + (thread_ldgvec_k * ThreadLdgVectorsL);
// Set predicate guard bits
guard |= (valid_l << ldgvec_idx);
residue_guard |= ((valid_l & valid_k) << ldgvec_idx);
}
}
// Promote residue-guard to primary-guard if no full tiles remain
if (!wholek_tiles_remaining)
{
guard = residue_guard;
}
}
// Update the input pointer to be matrix_thread_ldgvec_coords
this->d_matrix_ldgvecs =
reinterpret_cast<ldg_vector_t*>(d_matrix_items) +
(matrix_thread_ldgvec_coords.y * matrix_ldgvec_stride_k) +
(matrix_thread_ldgvec_coords.x * matrix_ldgvec_stride_l);
}
//-------------------------------------------------------------------------
// Loader API
//-------------------------------------------------------------------------
/**
* Request the current block-wide tile
*/
inline __device__
void request()
{
// Outer thread-tile ldg_vector_t iteration (K-axis)
#pragma unroll
for (int thread_ldgvec_k = 0; thread_ldgvec_k < ThreadLdgVectorsK; ++thread_ldgvec_k)
{
// Inner thread-tile ldg_vector_t iteration (L-axis)
#pragma unroll
for (int thread_ldgvec_l = 0; thread_ldgvec_l < ThreadLdgVectorsL; ++thread_ldgvec_l)
{
// Linear index of ldg_vector_t load
int ldgvec_idx = (thread_ldgvec_k * ThreadLdgVectorsL) + thread_ldgvec_l;
// Unpack predicate guard
predicate_mask_t valid = ((guard >> ldgvec_idx) & 1);
if (!AllowRaggedTiles || valid)
{
// Perform load
thread_tile[thread_ldgvec_k][thread_ldgvec_l].load(
d_matrix_ldgvecs +
(thread_ldgvec_k * StripmineLdgVectorsK * matrix_ldgvec_stride_k) +
(thread_ldgvec_l * StripmineLdgVectorsL * matrix_ldgvec_stride_l));
}
else
{
// Zero-initialize
#pragma unroll
for (int dpvec = 0; dpvec < LdgVectorDpVectors; ++dpvec)
thread_tile[thread_ldgvec_k][thread_ldgvec_l].buff[dpvec] = 0;
}
}
}
}
/**
* Advance the loader to the next block-wide tile in the K-axis
*/
inline __device__
void next()
{
d_matrix_ldgvecs += (matrix_ldgvec_stride_k * BlockLdgVectorsK);
if (AllowRaggedTiles)
{
--wholek_tiles_remaining;
// Promote residue-guard to primary-guard if no full tiles remain
if (!wholek_tiles_remaining)
{
guard = residue_guard;
}
}
}
/**
* Commit the previously-requested block-wide tile to shared memory
*
* NB: To facilitate padding for avoiding shared memory bank conflicts, we
* allow the row stride SmemDpVectorsL to be arbitrarily bigger than the
* tile width BlockDpVectorsL.
*/
template <int SmemDpVectorsL>
inline __device__
void commit(
dp_vector_t (&scratch_tile)[BlockDpVectorsK][SmemDpVectorsL])
{
static_assert(SmemDpVectorsL >= BlockDpVectorsL, "Row stride must be >= tile width.");
// Outer thread-tile ldg_vector_t iteration (K-axis)
#pragma unroll
for (int thread_ldgvec_k = 0; thread_ldgvec_k < ThreadLdgVectorsK; ++thread_ldgvec_k)
{
int block_ldgvec_k = block_thread_ldgvec_coords.y + (thread_ldgvec_k * StripmineLdgVectorsK);
// Inner thread-tile ldg_vector_t iteration (L-axis)
#pragma unroll
for (int thread_ldgvec_l = 0; thread_ldgvec_l < ThreadLdgVectorsL; ++thread_ldgvec_l)
{
int block_ldgvec_l = block_thread_ldgvec_coords.x + (thread_ldgvec_l * StripmineLdgVectorsL);
// Write column of dp_vector_t
#pragma unroll
for (int dpvec = 0; dpvec < LdgVectorDpVectors; ++dpvec)
{
scratch_tile[(block_ldgvec_k * LdgVectorDpVectors) + dpvec][block_ldgvec_l] =
thread_tile[thread_ldgvec_k][thread_ldgvec_l].buff[dpvec];
}
}
}
}
};
} // namespace gemm
} // namespace cutlass

View File

@@ -0,0 +1,314 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* Tile-loading abstraction for thread blocks
*/
#include "../util/util.h"
namespace cutlass {
namespace gemm {
/**
* block-wide tile loader supporting congruous mapping of data from source and
* destination addressable storage. Typically, this will be used to load a
* block-wide tile from global memory into shared memory.
*
* This enables the caller to specify MatrixAlignBytes guarantees of the input pointer
* and performs memory operations on vectors. This increases the efficiency of
* memory operations and reduces the number of guard predicates needed.
*
*/
template <
bool congruous, ///< Indicates whether the "GEMM K" dimension refers to strided matrix dimension
int BlockThreads, ///< Number of threads participating in the streaming operation
int BlockItemsL, ///< Extent of block-wide tile in value_t along the L-axis (width)
int BlockItemsK, ///< Extent of block-wide tile in value_t along the K-axis (height)
typename value_t, ///< Input matrix value type
int MatrixAlignBytes, ///< Byte alignment of input matrix
bool AllowRaggedTiles ///< Whether the input matrix's dimensions need not be an even-multiple of the block-wide tile dimensions
>
struct block_loader_wmma
{
//-------------------------------------------------------------------------
// Constants and types
//-------------------------------------------------------------------------
/// Predicate bit vector
typedef uint64_t predicate_mask_t;
/// Data movement type, coarsened by MatrixAlignBytes
typedef io_vector<
value_t,
divide_assert<MatrixAlignBytes, sizeof(value_t)>::value,
MatrixAlignBytes>
ldg_vector_t;
enum
{
/// Number of items per ldg_vector_t
LdgVectorItems = ldg_vector_t::VectorItems,
/// Total number of ldg_vector_t within the block-wide tile
BlockLdgVectors = divide_assert<(BlockItemsL * BlockItemsK), LdgVectorItems>::value,
/// Extent of the block-wide tile in ldg_vector_t along K-axis
BlockLdgVectorsK = BlockItemsK,
/// Extent of the block-wide tile in ldg_vector_t along L-axis
BlockLdgVectorsL = divide_assert<BlockItemsL, LdgVectorItems>::value,
/// Number of ldg_vector_t within each thread tile
ThreadLdgVectors = divide_assert<BlockLdgVectors, BlockThreads>::value,
/// Extent of the thread tile in ldg_vector_t along the L-axis
ThreadLdgVectorsL = __NV_STD_MAX(1, BlockLdgVectorsL / BlockThreads),
/// Block-wide strip-mining distance between ldg_vector_t along the K-axis
BlockLdgVectorStrideK = __NV_STD_MAX(1, BlockThreads / BlockLdgVectorsL),
/// Extent of the thread tile in ldg_vector_t along the K-axis
ThreadLdgVectorsK = divide_assert<BlockLdgVectorsK, BlockLdgVectorStrideK>::value,
};
//-------------------------------------------------------------------------
// Assert assumptions
//-------------------------------------------------------------------------
/// Define assertions
static_assert(ThreadLdgVectorsL * ThreadLdgVectorsK == ThreadLdgVectors,
"Number of vectors must be fully covered by the thread's 2D vector tile.");
/// Predicate masks must be large enough to guard every vector load
static_assert(sizeof(predicate_mask_t) * 8 >= ThreadLdgVectorsL * ThreadLdgVectorsK,
"Predicate bit vector must be large enough to guard every vector load.");
//-------------------------------------------------------------------------
// Members
//-------------------------------------------------------------------------
/// pointer to tile in global memory
const ldg_vector_t *ptr;
/// stride of the matrix in the K-axis
int matrix_values_stride_k;
/// Guard predicate
predicate_mask_t guard;
/// Guard for the last request iteration
predicate_mask_t residue_guard;
/// Number of 'whole' request iterations before encountering the residue
int request_iterations;
/// fetch registers
ldg_vector_t fetch[ThreadLdgVectors];
/// Thread's base offset from the start of a block-wide tile
int thread_offset_l;
/// Thread's basae offset from the start of a block-wide tile
int thread_offset_k;
//-------------------------------------------------------------------------
// Constructor API
//-------------------------------------------------------------------------
/// Constructor
inline __device__
block_loader_wmma(
const value_t *d_matrix, ///< Pointer to input matrix
int matrix_values_l, ///< Extent of the input matrix in value_t along the L-axis
int start_l, ///< Starting location in tile
int dim_k, ///< Inner dimension of tile, used for computing guard predicates
int _matrix_values_stride_k, ///< Stride of K-axis of atrix
int start_k, ///< Tile's starting location
int2 block_begin_item_coords) ///< Thread block's starting value_t coordinates (l, k) within the input matrix
:
ptr(reinterpret_cast<const ldg_vector_t *>(d_matrix)),
matrix_values_stride_k(_matrix_values_stride_k / LdgVectorItems),
guard(0),
residue_guard(0)
{
// Compute block's starting coordinates in units of vectors
int block_base_l = block_begin_item_coords.x / LdgVectorItems;
int block_base_k = block_begin_item_coords.y;
// Compute a thread tiling of the block-wide tile
int tid = threadIdx.x;
thread_offset_l = tid % BlockLdgVectorsL;
thread_offset_k = tid / BlockLdgVectorsL;
// Add the block and thread offsets to the source pointer
ptr += (block_base_l + thread_offset_l) +
(block_base_k + thread_offset_k) * matrix_values_stride_k;
// When AllowRaggedTiles support is enabled, compute a bit vector of guard
// predicates
if (AllowRaggedTiles)
{
if (congruous)
{
request_iterations = (dim_k - start_k) / BlockItemsK;
}
else
{
request_iterations = (matrix_values_l - start_l) / BlockItemsL;
}
#pragma unroll
for (int k_idx = 0; k_idx < ThreadLdgVectorsK; ++k_idx)
{
#pragma unroll
for (int l_idx = 0; l_idx < ThreadLdgVectorsL; ++l_idx)
{
int item = l_idx + k_idx * ThreadLdgVectorsL;
// Global vector L and K indices
int vec_l = l_idx * BlockThreads;
int vec_k = k_idx * BlockLdgVectorStrideK;
predicate_mask_t pred;
predicate_mask_t residue_pred;
if (congruous)
{
pred = (((block_base_l + thread_offset_l + vec_l) * LdgVectorItems < matrix_values_l) ? 1 : 0);
residue_pred = ((block_base_k + thread_offset_k + vec_k < (dim_k % BlockItemsK)) ? 1 : 0);
}
else
{
pred = ((block_base_k + thread_offset_k + vec_k < dim_k) ? 1 : 0);
residue_pred = (((block_base_l + thread_offset_l + vec_l) * LdgVectorItems < (matrix_values_l % BlockItemsL)) ? 1 : 0);
}
// Update the guard and residue_guard word with predicate bits
guard |= (pred << item);
residue_guard |= (residue_pred << item);
}
}
// If there are zero full request iterations, compute the intersection
// with the residue guard.
if (!request_iterations)
{
guard &= residue_guard;
}
}
}
/**
* Request the current block-wide tile from source memory
*/
inline __device__
void request()
{
#pragma unroll
for (int k_idx = 0; k_idx < ThreadLdgVectorsK; ++k_idx)
{
#pragma unroll
for (int l_idx = 0; l_idx < ThreadLdgVectorsL; ++l_idx)
{
int load_idx = l_idx + (k_idx * ThreadLdgVectorsL);
bool pred = !AllowRaggedTiles || (guard & (predicate_mask_t(1) << load_idx));
if (pred)
{
fetch[load_idx].load(
ptr +
(k_idx * BlockLdgVectorStrideK * matrix_values_stride_k) + (l_idx * BlockThreads));
}
else
{
#pragma unroll
for (int elem_idx = 0; elem_idx < LdgVectorItems; ++elem_idx)
{
fetch[load_idx].buff[elem_idx] = 0;
}
}
}
}
}
/// Advance to the next block-wide tile
inline __device__
void next()
{
if (congruous)
{
ptr += BlockItemsK * matrix_values_stride_k;
}
else
{
ptr += BlockLdgVectorsL;
}
// Track number of full iterations to intersect with the residue guard predicates.
if (AllowRaggedTiles)
{
--request_iterations;
if (!request_iterations)
{
guard &= residue_guard;
}
}
}
/// Commit the values to the scratch tile to destination memory.
template <int SmemStride>
inline __device__
void commit(value_t *scratch_tile)
{
static_assert(SmemStride % LdgVectorItems == 0,
"SMEM stride must be divisible by the size of vector loads");
ldg_vector_t *smem_ptr = reinterpret_cast<ldg_vector_t *>(scratch_tile);
smem_ptr += thread_offset_l + thread_offset_k * SmemStride / LdgVectorItems;
#pragma unroll
for (int k_idx = 0; k_idx < ThreadLdgVectorsK; ++k_idx)
{
#pragma unroll
for (int l_idx = 0; l_idx < ThreadLdgVectorsL; ++l_idx)
{
int load_idx = l_idx + (k_idx * ThreadLdgVectorsL);
fetch[load_idx].store(smem_ptr +
(k_idx * BlockLdgVectorStrideK * SmemStride / LdgVectorItems) +
(l_idx * BlockThreads));
}
}
}
};
} // namespace gemm
} // namespace cutlass

669
cutlass/gemm/block_task.h Normal file
View File

@@ -0,0 +1,669 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* A block-wide task abstraction for computing device-wide GEMM
*/
#include <stdint.h>
#include "../util/util.h"
#include "grid_raster.h"
#include "block_loader.h"
#include "k_split_control.h"
#include "thread_accumulator.h"
namespace cutlass {
namespace gemm {
/******************************************************************************
* block_task_policy
******************************************************************************/
/**
* \brief Parameterizable tuning policy for \p block_task
*
* Once parameterized, \p block_task_policy provides the member constant
* \p BlockThreads indicating to the required thread block size
*/
template <
int _BlockItemsY, ///< Height in rows of a block-wide tile in matrix C
int _BlockItemsX, ///< Width in columns of a block-wide tile in matrix C
int _BlockItemsK, ///< Extent of block-wide A|B tiles in value_t along the K-axis
int _ThreadItemsY, ///< Height in rows of a thread tile in C
int _ThreadItemsX, ///< Width in columns of a thread tile in C
bool _UseDoubleScratchTiles, ///< Whether to halve synchronization overhead at the expense of doubled shared memory and addressing overhead
grid_raster_strategy::kind_t _RasterStrategy> ///< Strategy for enumerating \p block_task within an input matrix
struct block_task_policy
{
enum
{
/// Height in rows of a block-wide tile in matrix C
BlockItemsY = _BlockItemsY,
/// Width in columns of a block-wide tile in matrix C
BlockItemsX = _BlockItemsX,
/// Height in rows of a thread tile in C
ThreadItemsY = _ThreadItemsY,
/// Width in columns of a thread tile in C
ThreadItemsX = _ThreadItemsX,
/// Extent of block-wide A|B tiles in value_t along the K-axis
BlockItemsK = _BlockItemsK,
/// Whether to halve synchronization overhead at the expense of doubled shared memory and addressing overhead
UseDoubleScratchTiles = _UseDoubleScratchTiles,
/// Number of threads in each thread block (blockDim.x)
BlockThreads = divide_assert<
(BlockItemsY * BlockItemsX),
(ThreadItemsY * ThreadItemsX)>::value,
};
/// Strategy for enumerating \p block_task within an input matrix
static const grid_raster_strategy::kind_t RasterStrategy = _RasterStrategy;
};
/******************************************************************************
* block_task
******************************************************************************/
/**
* \brief A block-wide task abstraction for computing device-wide GEMM
*
* Each thread_block is assigned a unique tile of output matrix C to compute by
* consuming the corresponding stripes of the input matrices A and B.
*/
template <
typename block_task_policy_t, ///< Parameterization of block_task_policy
typename value_t, ///< Multiplicand value type (matrices A and B)
typename accum_t, ///< Accumulator value type (matrix C and scalars)
matrix_transform_t::kind_t TransformA, ///< View transform enumerant for matrix A
int LdgAlignA, ///< Alignment (in bytes) for A operand
matrix_transform_t::kind_t TransformB, ///< View transform enumerant for matrix B
int LdgAlignB, ///< Alignment (in bytes) for B operand
typename epilogue_op_t, ///< Epilogue operation applied to GEMM
int LdgAlignC, ///< Alignment (in bytes) for C operand
bool AllowRaggedTiles ///< Whether the input matrix's dimensions need not be an even-multiple of the block-wide tile dimensions
>
struct block_task
{
//-------------------------------------------------------------------------
// Constants and types
//-------------------------------------------------------------------------
enum
{
/// Number of threads in each thread block (blockDim.x)
BlockThreads = block_task_policy_t::BlockThreads,
/// Extent of thread tile in value_t along M-axis
ThreadItemsY = block_task_policy_t::ThreadItemsY,
/// Extent of thread tile in value_t along N-axis
ThreadItemsX = block_task_policy_t::ThreadItemsX,
};
/// Accumulator type
typedef thread_accumulator<
ThreadItemsY,
ThreadItemsX,
value_t,
accum_t>
thread_accumulator_t;
/// Dot-product vector type along the K-axis (e.g, uchar4 when using IDP4A)
typedef typename thread_accumulator_t::dp_vector_t dp_vector_t;
enum
{
/// Whether this is a small, latency-bound tile
IsSmallTile = (ThreadItemsY < 4) && (ThreadItemsX < 4),
/// Number of value_t in dp_vector_t
DpVectorItems = divide_assert<sizeof(dp_vector_t), sizeof(value_t)>::value,
/// Extent of block-wide C-tile in accum_t (and A-tiles in value_t) along M-axis (height)
BlockItemsY = block_task_policy_t::BlockItemsY,
/// Extent of block-wide C-tile in accum_t (and B-tiles in value_t) along N-axis (width)
BlockItemsX = block_task_policy_t::BlockItemsX,
/// Extent of block-wide A|B tiles in value_t along the K-axis
BlockItemsK = block_task_policy_t::BlockItemsK,
/// Whether to halve synchronization overhead at the expense of doubled shared memory and addressing overhead
UseDoubleScratchTiles = block_task_policy_t::UseDoubleScratchTiles,
/// Extent of block-wide A|B tiles in dp_vector_t along the K-axis
BlockDpVectorsK = divide_assert<BlockItemsK, DpVectorItems>::value,
/// Number of dp_vector_t along M-axis that can be read in a single LDS from the shared A-tile (up to 128b if more than one value_t)
LdsVectorDpVectorsA = __NV_STD_MIN(
ThreadItemsY,
__NV_STD_MAX(1, (128 / (__NV_STD_MAX(sizeof(dp_vector_t), sizeof(accum_t)) * 8)))),
/// Number of dp_vector_t along N-axis that can be read in a single LDS from the shared B-tile (up to 128b if more than one value_t)
LdsVectorDpVectorsB = __NV_STD_MIN(
ThreadItemsX,
__NV_STD_MAX(1, (128 / (__NV_STD_MAX(sizeof(dp_vector_t), sizeof(accum_t)) * 8)))),
/// Number of strip-mined LDS vector reads from shared A-tile
ThreadLdsVectorsA = divide_assert<ThreadItemsY, LdsVectorDpVectorsA>::value,
/// Number of strip-mined LDS vector reads from shared B-tile
ThreadLdsVectorsB = divide_assert<ThreadItemsX, LdsVectorDpVectorsB>::value,
/// Number of elements in one LDG/STG vector of C-tile
ThreadLdgVectorSizeC = __NV_STD_MIN(LdgAlignC, 16) / (sizeof(accum_t)),
/// Number of threads in warp
WarpThreads = 32,
/// Extent of warp in threads along the M-axis
WarpThreadsY = (BlockItemsY > BlockItemsX) ? 8 : 4,
/// Extent of warp in threads along the N-axis
WarpThreadsX = divide_assert<WarpThreads, WarpThreadsY>::value,
/// Extent of warp-wide tile in items along the M-axis
WarpItemsY = WarpThreadsY * ThreadItemsY,
/// Extent of warp-wide tile in items along the N-axis
WarpItemsX = WarpThreadsX * ThreadItemsX,
/// Extent of block in warps along M-axis
BlockWarpsY = divide_assert<BlockItemsY, WarpItemsY>::value,
/// Extent of block in warps along N-axis
BlockWarpsX = divide_assert<BlockItemsX, WarpItemsX>::value,
};
/// Load-from-shared data movement type for A-tile, coarsened by LdsVectorDpVectorsA
typedef io_vector<dp_vector_t, LdsVectorDpVectorsA> lds_vector_a_t;
/// Load-from-shared data movement type for B-tile, coarsened by LdsVectorDpVectorsB
typedef io_vector<dp_vector_t, LdsVectorDpVectorsB> lds_vector_b_t;
/// Thread block rasterization helper type
typedef grid_raster<
BlockItemsY,
BlockItemsX,
TransformA,
TransformB,
block_task_policy_t::RasterStrategy>
grid_raster_t;
/// Tile loader type for matrix A
typedef block_loader<
BlockThreads, // BlockThreads
BlockDpVectorsK, // BlockDpVectorsK
BlockItemsY, // BlockItemsL
value_t, // value_t
LdgAlignA, // MatrixAlignBytes
AllowRaggedTiles, // AllowRaggedTiles
dp_vector_t, // dp_vector_t
(TransformA == matrix_transform_t::NonTranspose) ? // LoadAlgorithm
load_algorithm::CongruousCopy :
load_algorithm::CrosswiseCopy>
block_loader_a_t;
/// Tile loader type for matrix B
typedef block_loader<
BlockThreads, // BlockThreads
BlockDpVectorsK, // BlockDpVectorsK
BlockItemsX, // BlockItemsL
value_t, // value_t
LdgAlignB, // MatrixAlignBytes
AllowRaggedTiles, // AllowRaggedTiles
dp_vector_t, // dp_vector_t
(TransformB == matrix_transform_t::NonTranspose) ? // LoadAlgorithm
load_algorithm::CrosswiseCopy :
load_algorithm::CongruousCopy>
block_loader_b_t;
enum
{
/// Number of value_t to pad the end of each row of the shared A-tile
PadItemsA = (TransformA == matrix_transform_t::NonTranspose) ?
__NV_STD_MAX(LdsVectorDpVectorsA, block_loader_a_t::AlignmentDpVectorsL) :
LdsVectorDpVectorsA,
/// Number of value_t to pad the end of each row of the shared B-tile
PadItemsB = (TransformB == matrix_transform_t::NonTranspose) ?
LdsVectorDpVectorsB :
__NV_STD_MAX(LdsVectorDpVectorsB, block_loader_b_t::AlignmentDpVectorsL),
};
/// Shared memory layout for a prefetch page
struct page_storage_t
{
/// Tile of A
dp_vector_t __align__(16) block_a[BlockDpVectorsK][BlockItemsY + PadItemsA];
/// Tile of B
dp_vector_t __align__(16) block_b[BlockDpVectorsK][BlockItemsX + PadItemsB];
};
/// Shared memory layout for scratch storage
struct scratch_storage_t
{
/// Prefetch pages
page_storage_t pages[UseDoubleScratchTiles ? 2 : 1];
/// Accumulator shared scratch
typename thread_accumulator_t::scratch_storage_t accum_scratch;
};
//-------------------------------------------------------------------------
// Assert assumptions
//-------------------------------------------------------------------------
// Ensure we have at least two unrolled innermost loop iterations (one to prefetch
// the next global tile and then one to prefetch the first strip of it from shared)
static_assert ((BlockDpVectorsK >= 2), "BlockDpVectorsK must be >= 2.");
//-------------------------------------------------------------------------
// Members
//-------------------------------------------------------------------------
/// Scratch storage reference
scratch_storage_t *scratch;
/// Which page of scratch tiles we're currently reading from
int page_idx;
/// Pointer to matrix C
accum_t *d_c;
/// Epilogue operation applied to update matrix C
epilogue_op_t epilogue_op;
/// Matrix height in rows of trans_op(A) and C
int dim_m;
/// Matrix width in columns of trans_op(B) and C
int dim_n;
/// Control for inter-block k-splitting
k_split_control k_split;
/// Thread block's base value_t coordinates (m, n) in matrix C
grid_raster_t grid_raster;
/// Thread block's current coordinate (k) within A|B matrices
int block_item_coords_k;
/// Thread block's ending coordinate (k) within A|B matrices (one-past)
int block_end_item_k;
/// Warp's coordinates (x, y) in thread block
int2 block_warp_coords;
/// Thread's coordinates (x, y) in warp
int2 warp_thread_coords;
/// Thread's base item offset within strip of A tile
int thread_strip_offset_a;
/// Thread's base item offset within strip of B tile
int thread_strip_offset_b;
/// Thread's active-k/prefetch-k slices from shared A tile
lds_vector_a_t local_slices_a[2][ThreadLdsVectorsA];
/// Thread's active-k/prefetch-k slices from shared B tile
lds_vector_b_t local_slices_b[2][ThreadLdsVectorsB];
/// A tile loader
block_loader_a_t loader_a;
/// B tile loader
block_loader_b_t loader_b;
/// C tile accumulator
thread_accumulator_t accumulator;
//-------------------------------------------------------------------------
// Coordinate system helpers
//-------------------------------------------------------------------------
/// Compute the warp's coordinates (x, y) in thread block
inline __device__
int2 warp_coords()
{
int warp_id = threadIdx.x / WarpThreads;
return make_int2(
warp_id % BlockWarpsX,
warp_id / BlockWarpsX);
}
/// Compute the thread's lane-coordinates (x, y) in warp
inline __device__
int2 thread_coords()
{
int lane_id = threadIdx.x % WarpThreads;
// Maxwell+ mapping of threads within a 2D warp for maximal LDS bandwidth
return make_int2(
lane_id / WarpThreadsY,
lane_id % WarpThreadsY);
}
//-------------------------------------------------------------------------
// Constructor API
//-------------------------------------------------------------------------
/// Constructor
inline __device__
block_task(
scratch_storage_t *scratch,
value_t *d_a,
value_t *d_b,
accum_t *d_c,
epilogue_op_t epilogue_op,
int dim_m,
int dim_n,
int dim_k,
k_split_control k_split)
:
scratch(scratch),
page_idx(0),
d_c(d_c),
epilogue_op(epilogue_op),
dim_m(dim_m),
dim_n(dim_n),
k_split(k_split),
block_item_coords_k(k_split.block_begin_item_k()),
block_end_item_k(k_split.block_end_item_k(dim_k)),
block_warp_coords(warp_coords()),
warp_thread_coords(thread_coords()),
thread_strip_offset_a((warp_thread_coords.y * LdsVectorDpVectorsA) + (block_warp_coords.y * WarpItemsY)),
thread_strip_offset_b((warp_thread_coords.x * LdsVectorDpVectorsB) + (block_warp_coords.x * WarpItemsX)),
loader_a(
d_a, // d_matrix
dim_m, // matrix_values_l
(TransformA == matrix_transform_t::NonTranspose) ? dim_m : 1, // matrix_values_stride_k
(TransformA == matrix_transform_t::NonTranspose) ? 1 : dim_k, // matrix_values_stride_l
make_int2( // block_begin_item_coords
grid_raster.block_item_coords.y,
block_item_coords_k),
block_end_item_k), // block_end_item_k
loader_b(
d_b, // d_matrix
dim_n, // matrix_values_l
(TransformB == matrix_transform_t::NonTranspose) ? 1 : dim_n, // matrix_values_stride_k
(TransformB == matrix_transform_t::NonTranspose) ? dim_k : 1, // matrix_values_stride_l
make_int2( // block_begin_item_coords
grid_raster.block_item_coords.x,
block_item_coords_k),
block_end_item_k), // block_end_item_k
accumulator(scratch->accum_scratch)
{}
//-------------------------------------------------------------------------
// Prefetching utility methods
//-------------------------------------------------------------------------
/**
* Request the calling thread's slices of the shared tiles at depth \p tile_offset_k
*/
inline __device__ void request_local_prefetch(
lds_vector_a_t (&slice_a)[ThreadLdsVectorsA], ///< Slice from A
lds_vector_b_t (&slice_b)[ThreadLdsVectorsB], ///< Slice from B
int tile_offset_k)
{
// Load B strip
for (int i = 0; i < ThreadLdsVectorsB; ++i)
{
slice_b[i].load(
&scratch->pages[page_idx].block_b[tile_offset_k][thread_strip_offset_b + (i * WarpThreadsX * LdsVectorDpVectorsB)]);
}
// Load A strip
for (int i = 0; i < ThreadLdsVectorsA; ++i)
{
slice_a[i].load(
&scratch->pages[page_idx].block_a[tile_offset_k][thread_strip_offset_a + (i * WarpThreadsY * LdsVectorDpVectorsA)]);
}
}
//-------------------------------------------------------------------------
// Epilogue
//-------------------------------------------------------------------------
/**
* Performs the GEMM epilogue:
* - Applies the scalar multipliers and addends to the accumulators
* - Write the result to the output matrix
*/
inline __device__ void epilogue()
{
// Wait for predecessor thread block(s) to produce block-wide tile of
// exclsuive partial-sums
k_split.wait();
// Configure epilogue as to whether the thread block is a secondary
// accumulator in an inter-block k-splitting scheme
if (k_split.is_secondary_accumulator())
epilogue_op.set_secondary_accumulator();
// Whether the addend from C needs loading
bool must_init_addend = epilogue_op.must_init_addend();
#pragma unroll
for (int x = 0; x < ThreadItemsX; ++x)
{
#pragma unroll
for (int y = 0; y < ThreadItemsY; y += LdsVectorDpVectorsA)
{
int thread_strip_b = x / LdsVectorDpVectorsB;
int thread_strip_a = y / LdsVectorDpVectorsA;
int thread_item_coords_tile_x = thread_strip_offset_b + (thread_strip_b * WarpThreadsX * LdsVectorDpVectorsB) + (x % LdsVectorDpVectorsB);
int thread_item_coords_tile_y = thread_strip_offset_a + (thread_strip_a * WarpThreadsY * LdsVectorDpVectorsA) + (y % LdsVectorDpVectorsA);
int c_idx = (grid_raster.block_item_coords.x + thread_item_coords_tile_x) * dim_m +
grid_raster.block_item_coords.y + thread_item_coords_tile_y;
accum_t *my_c = d_c + c_idx;
#pragma unroll
for (int i = 0; i < LdsVectorDpVectorsA; ++i)
{
accum_t c_slice = accum_t(0);
accum_t *c_ptr = my_c + i;
if ((grid_raster.block_item_coords.x + thread_item_coords_tile_x) < dim_n &&
(grid_raster.block_item_coords.y + thread_item_coords_tile_y + i) < dim_m)
{
if (must_init_addend)
{
ldg_cg(c_slice, c_ptr);
}
c_slice = epilogue_op(accumulator.get(x, y + i), c_slice, c_idx + i);
stg_cg(c_ptr, c_slice);
}
}
}
}
// Signal k-split successor thread_block that we have produced our block-wide
// tile of inclusive partial-sums
k_split.signal();
}
//-------------------------------------------------------------------------
// Tile consumption
//-------------------------------------------------------------------------
/**
* Consume a tile of A and B each
*/
template <bool DoGlobalPrefetch>
inline __device__
void consume_tile()
{
// Unroll BlockDpVectorsK iterations of outer-product accumulations
#pragma unroll
for (int tile_offset_k = 0; tile_offset_k < BlockDpVectorsK; tile_offset_k += 1)
{
// Last strip commits global prefetch for next tile
if ((tile_offset_k == BlockDpVectorsK - 1) && DoGlobalPrefetch)
{
// If not using two pages of scratch tiles, protect the above prefetch loads from the committing writes below
if (!UseDoubleScratchTiles)
__syncthreads();
// If using two pages of scratch tiles, switch to next page before writing
if (UseDoubleScratchTiles)
{
page_idx = (page_idx ? 0 : 1);
}
// Commit global prefetch data to scratch page
loader_a.commit(scratch->pages[page_idx].block_a);
loader_b.commit(scratch->pages[page_idx].block_b);
__syncthreads();
}
// Request local prefetch for next strip
request_local_prefetch(
local_slices_a[(tile_offset_k + 1) % 2],
local_slices_b[(tile_offset_k + 1) % 2],
(tile_offset_k + 1) % BlockDpVectorsK);
// Request global prefetch for next tile on first strip
if ((tile_offset_k == 0) && DoGlobalPrefetch)
{
loader_b.request();
loader_b.next();
loader_a.request();
loader_a.next();
}
// Cast strip-mined loads to contiguous array of dp_vector_t
typedef dp_vector_t thread_tile_a_t[ThreadLdsVectorsA * LdsVectorDpVectorsA];
typedef dp_vector_t thread_tile_b_t[ThreadLdsVectorsB * LdsVectorDpVectorsB];
thread_tile_a_t &thread_tile_a = reinterpret_cast<thread_tile_a_t&>(local_slices_a[(tile_offset_k) % 2]);
thread_tile_b_t &thread_tile_b = reinterpret_cast<thread_tile_b_t&>(local_slices_b[(tile_offset_k) % 2]);
// Accumulate this dp-stripe product
accumulator.multiply_accumulate(thread_tile_a, thread_tile_b);
}
}
//-------------------------------------------------------------------------
// GEMM API
//-------------------------------------------------------------------------
/**
* Compute GEMM
*/
inline __device__
void run()
{
// Quit if the thread block is fully out-of-bounds
if (grid_raster.is_block_oob(dim_m, dim_n))
{
asm volatile("exit;");
}
// Request global prefetch of first tile
loader_a.request();
loader_a.next();
loader_b.request();
loader_b.next();
// Commit global prefetch of first tile to shared memory
loader_a.commit(scratch->pages[page_idx].block_a);
loader_b.commit(scratch->pages[page_idx].block_b);
// Advance to next A,B tiles in K-axis
block_item_coords_k += BlockItemsK;
// Synchronize shared tiles and prepared accumulator
__syncthreads();
// Initialize thread's slice of accumulators
accumulator.init();
// Request first iteration of local prefetch strips
request_local_prefetch(
local_slices_a[0],
local_slices_b[0],
0);
//
// Main loop
//
// Consume tiles in A and B along the K-axis (all but last tile)
#pragma unroll 1
while (block_item_coords_k < block_end_item_k)
{
consume_tile<true>();
// Advance to next A,B tiles in K-axis
block_item_coords_k += BlockItemsK;
}
// Consume last tile
consume_tile<false>();
//
// Eplilogue
//
epilogue();
}
};
} // namespace gemm
} // namespace cutlass

View File

@@ -0,0 +1,759 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
/**
* \file
* A block-wide task abstraction for computing device-wide GEMM
*/
#pragma once
// Compiler guard conditional to avoid compilation errors on versions of CUDA that
// do not support the WMMA API.
#if defined (WMMA)
#include <stdint.h>
#include "../util/util.h"
#include "grid_raster.h"
#include "block_loader.h"
#include "block_loader_wmma.h"
#include "wmma_accumulator.h"
namespace cutlass {
namespace gemm {
/******************************************************************************
* block_task_wmma_policy
******************************************************************************/
/**
* \brief Parameterizable tuning policy for block-wide WMMA GEMM tasks
*
* Once parameterized, \p block_task_policy provides the member constant
* \p BlockThreads indicating to the required thread block size
*/
template <
int _BlockItemsY, ///< Height in rows of a block-wide tile in matrix C
int _BlockItemsX, ///< Width in columns of a block-wide tile in matrix C
int _BlockItemsK, ///< Extent of block-wide A|B tiles in value_t along the K-axis
int _WarpItemsY, ///< Height in rows of a Warp tile's accumulators
int _WarpItemsX, ///< Width in columns of a Warp tile's accumulators
int _WmmaItemsY, ///< Height in rows of a discrete WMMA block's accumulators
int _WmmaItemsX, ///< Width in columns of a discrete WMMA block's accumulators
int _WmmaItemsK, ///< Depth of each discrete WMMA block
bool _UseDoubleScratchTiles, ///< Whether to halve synchronization overhead at the expense of doubled shared memory and addressing overhead
grid_raster_strategy::kind_t _RasterStrategy> ///< Strategy for enumerating \p block_task within an input matrix
struct block_task_wmma_policy
{
/// Strategy for enumerating \p block_task within an input matrix
static const grid_raster_strategy::kind_t RasterStrategy = _RasterStrategy;
enum
{
/// Height in rows of a block-wide tile in matrix C
BlockItemsY = _BlockItemsY,
/// Width in columns of a block-wide tile in matrix C
BlockItemsX = _BlockItemsX,
/// Extent of block-wide A|B tiles in value_t along the K-axis
BlockItemsK = _BlockItemsK,
/// Height in rows of a Warp tile's accumulators
WarpItemsX = _WarpItemsX,
/// Width in columns of a Warp tile's accumulators
WarpItemsY = _WarpItemsY,
/// Width in columns of a discrete WMMA block's accumulators
WmmaItemsX = _WmmaItemsX,
/// Height in rows of a discrete WMMA block's accumulators
WmmaItemsY = _WmmaItemsY,
/// Depth of each discrete WMMA block
WmmaItemsK = _WmmaItemsK,
/// Whether to halve synchronization overhead at the expense of doubled shared memory and addressing overhead
UseDoubleScratchTiles = _UseDoubleScratchTiles,
//
// Derived quantities
//
/// Machine warp size
WarpThreads = 32,
/// Number of WMMA operations in the height dimension
WmmaBlocksY = divide_assert<WarpItemsY, WmmaItemsY>::value,
/// Number of WMMA operations in the height dimension
WmmaBlocksX = divide_assert<WarpItemsX, WmmaItemsX>::value,
/// Number of warps in each thread block
BlockWarps = divide_assert<BlockItemsY * BlockItemsX, WarpItemsX * WarpItemsY>::value,
/// Number of threads in each thread block (blockDim.x)
BlockThreads = BlockWarps * WarpThreads,
};
};
/******************************************************************************
* block_task_wmma
******************************************************************************/
/**
* \brief A block-wide task abstraction for computing device-wide GEMM
*
* Each thread_block is assigned a unique tile of output matrix C to compute by
* consuming the corresponding stripes of the input matrices A and B.
*/
template <
typename block_task_policy_t, ///< Parameterization of block_task_policy
typename value_t, ///< Multiplicand value type (matrices A and B)
typename accum_t, ///< Accumulator value type (matrix C and scalars)
matrix_transform_t::kind_t TransformA, ///< View transform enumerant for matrix A
int LdgAlignA, ///< Alignment (in bytes) for A operand
matrix_transform_t::kind_t TransformB, ///< View transform enumerant for matrix B
int LdgAlignB, ///< Alignment (in bytes) for B operand
typename epilogue_op_t, ///< Epilogue operation to update matrix C
int LdgAlignC, ///< Alignment (in bytes) for C operand
bool AllowRaggedTiles ///< Whether the input matrix's dimensions need not be an even-multiple of the block-wide tile dimensions
>
struct block_task_wmma
{
//-------------------------------------------------------------------------
// Constants and types
//-------------------------------------------------------------------------
enum
{
/// Number of threads in each thread block (blockDim.x)
BlockThreads = block_task_policy_t::BlockThreads,
/// Extent of block-wide C-tile in accum_t (and A-tiles in value_t) along M-axis (height)
BlockItemsY = block_task_policy_t::BlockItemsY,
/// Extent of block-wide C-tile in accum_t (and B-tiles in value_t) along N-axis (width)
BlockItemsX = block_task_policy_t::BlockItemsX,
/// Extent of block-wide A|B tiles in value_t along the K-axis
BlockItemsK = block_task_policy_t::BlockItemsK,
/// Extent of warp C-tile in accum_t (and A-tiles in value_t) along M-axis (height)
WarpItemsY = block_task_policy_t::WarpItemsY,
/// Extent of warp C-tile in accum_t (and B-tiles in value_t) along N-axis (width)
WarpItemsX = block_task_policy_t::WarpItemsX,
/// Extent of warp C-tile in accum_t (and A-tiles in value_t) along M-axis (height)
WmmaItemsY = block_task_policy_t::WmmaItemsY,
/// Extent of warp C-tile in accum_t (and B-tiles in value_t) along N-axis (width)
WmmaItemsX = block_task_policy_t::WmmaItemsX,
/// Extent of warp-wide A|B-tiles in value_t along K-axis
WmmaItemsK = block_task_policy_t::WmmaItemsK,
/// Whether to halve synchronization overhead at the expense of doubled shared memory and addressing overhead
UseDoubleScratchTiles = block_task_policy_t::UseDoubleScratchTiles,
/// Number of threads in warp
WarpThreads = block_task_policy_t::WarpThreads,
/// Number of warps participating
BlockWarps = block_task_policy_t::BlockWarps,
/// Extent of block in warps along M-axis
BlockWarpsY = divide_assert<BlockItemsY, WarpItemsY>::value,
/// Extent of block in warps along N-axis
BlockWarpsX = divide_assert<BlockItemsX, WarpItemsX>::value,
/// Number of MMA unrolls
WmmaUnrollCount = divide_assert<BlockItemsK, WmmaItemsK>::value,
/// True if the A matrix layout is column major (K is the strided dimension)
IsLayoutCongruousA = (TransformA == matrix_transform_t::NonTranspose),
/// True if the B matrix layout is row mayor (K is the strided dimension)
IsLayoutCongruousB = (TransformB == matrix_transform_t::Transpose),
};
/// WMMA may support unique types for A and B, so plan ahead for this
typedef value_t value_a_t;
/// WMMA may support unique types for A and B, so plan ahead for this
typedef value_t value_b_t;
/// WMMA accumulator type
typedef wmma_accumulator<
WarpItemsY,
WarpItemsX,
WmmaItemsY,
WmmaItemsX,
WmmaItemsK,
value_a_t,
value_b_t,
accum_t,
TransformA,
TransformB>
accumulator_t;
/// Thread block rasterization helper type
typedef grid_raster<
BlockItemsY,
BlockItemsX,
TransformA,
TransformB,
block_task_policy_t::RasterStrategy>
grid_raster_t;
/// Tile loader type for matrix A
typedef block_loader_wmma<
IsLayoutCongruousA,
BlockThreads,
(IsLayoutCongruousA ? BlockItemsY : BlockItemsK),
(IsLayoutCongruousA ? BlockItemsK : BlockItemsY),
value_a_t,
LdgAlignA,
AllowRaggedTiles>
block_loader_a_t;
/// Tile loader type for matrix A
typedef block_loader_wmma<
IsLayoutCongruousB,
BlockThreads,
(IsLayoutCongruousB ? BlockItemsX : BlockItemsK),
(IsLayoutCongruousB ? BlockItemsK : BlockItemsX),
value_b_t,
LdgAlignB,
AllowRaggedTiles>
block_loader_b_t;
/// Type alias for matrix A fragment type
typedef typename accumulator_t::fragment_a_t fragment_a_t;
/// Type alias for matrix B fragment type
typedef typename accumulator_t::fragment_b_t fragment_b_t;
enum
{
/// Number of fragments from A matrix
WmmaBlocksY = accumulator_t::WmmaBlocksY,
/// Number of fragments from B matrix
WmmaBlocksX = accumulator_t::WmmaBlocksX,
/// Number of value_t to pad the outer dimension of the shared A-tile
PadItemsA = 16,
/// Number of value_t to pad the outer dimension of the shared B-tile
PadItemsB = 16,
/// Leading dimension of A matrix tile
LdmSmemA = (IsLayoutCongruousA ? BlockItemsY: BlockItemsK) + PadItemsA,
/// Leading dimension of A matrix tile
StridedSmemA = (IsLayoutCongruousA ? BlockItemsK : BlockItemsY ),
/// Leading dimension of B matrix tile
LdmSmemB = (IsLayoutCongruousB? BlockItemsX : BlockItemsK) + PadItemsB,
StridedSmemB = (IsLayoutCongruousB ? BlockItemsK : BlockItemsX),
};
/// Shared memory layout for a prefetch page
struct page_storage_t
{
/// Tile of A
value_a_t __align__(16) block_a[StridedSmemA][LdmSmemA];
/// Tile of B
value_b_t __align__(16) block_b[StridedSmemB][LdmSmemB];
};
/// Shared memory layout for scratch storage
struct scratch_storage_t
{
union
{
/// Prefetch pages
uninitialized<page_storage_t> pages[UseDoubleScratchTiles ? 2 : 1];
/// Scratch storage for warps
accum_t epilogue[BlockWarps][WmmaItemsX * WmmaItemsY];
};
};
//-------------------------------------------------------------------------
// Assert assumptions
//-------------------------------------------------------------------------
// Ensure we have at least two unrolled innermost loop iterations (one to prefetch
// the next global tile and then one to prefetch the first strip of it from shared)
static_assert ((BlockItemsK >= 2), "BlockItemsK must be >= 2.");
//-------------------------------------------------------------------------
// Members
//-------------------------------------------------------------------------
/// Scratch storage reference
scratch_storage_t *scratch;
/// Which page of scratch tiles we're currently reading from
int page_idx;
/// Pointer to matrix C
accum_t *d_c;
/// Epilogue operation applied to update matrix C
epilogue_op_t epilogue_op;
/// Matrix height in rows of trans_op(A) and C
int dim_m;
/// Matrix width in columns of trans_op(B) and C
int dim_n;
/// Control for inter-block k-splitting
k_split_control k_split;
/// Thread block's base value_t coordinates (m, n) in matrix C
grid_raster_t grid_raster;
/// Thread block's current coordinate (k) within A|B matrices
int block_item_coords_k;
/// Thread block's ending coordinate (k) within A|B matrices (one-past)
int block_end_item_k;
/// Warp's coordinates (x, y) in thread block
int2 block_warp_item_coords;
/// A tile loader
block_loader_a_t loader_a;
/// B tile loader
block_loader_b_t loader_b;
/// Thread's active-k/prefetch-k slices from shared A tile
fragment_a_t local_slices_a[2][WmmaBlocksY];
/// Thread's active-k/prefetch-k slices from shared B tile
fragment_b_t local_slices_b[2][WmmaBlocksX];
/// Accumulator tile
accumulator_t accumulator;
//-------------------------------------------------------------------------
// Coordinate system helpers
//-------------------------------------------------------------------------
/// Compute the warp's item-coordinates (x, y) in thread block
inline __device__
int2 warp_item_coords()
{
int warp_id = threadIdx.x / WarpThreads;
return make_int2(
(warp_id / BlockWarpsY) * WarpItemsX,
(warp_id % BlockWarpsY) * WarpItemsY);
}
/// Compute the thread block's base item-coordinates in matrix A
inline __device__
int2 a_block_item_coords()
{
if (TransformA == matrix_transform_t::NonTranspose)
{
return make_int2(grid_raster.block_item_coords.y, block_item_coords_k);
}
else
{
return make_int2(block_item_coords_k, grid_raster.block_item_coords.y);
}
}
/// Compute the thread block's base item-coordinates in matrix B
inline __device__
int2 b_block_item_coords()
{
if (TransformB == matrix_transform_t::Transpose)
{
return make_int2(grid_raster.block_item_coords.x, block_item_coords_k);
}
else
{
return make_int2(block_item_coords_k, grid_raster.block_item_coords.x);
}
}
//-------------------------------------------------------------------------
// Constructor API
//-------------------------------------------------------------------------
/// Constructor
inline __device__
block_task_wmma(
scratch_storage_t *scratch,
value_t *d_a,
value_t *d_b,
accum_t *d_c,
epilogue_op_t epilogue_op,
int dim_m,
int dim_n,
int dim_k,
k_split_control k_split)
:
scratch(scratch),
page_idx(0),
d_c(d_c),
epilogue_op(epilogue_op),
dim_m(dim_m),
dim_n(dim_n),
k_split(k_split),
block_item_coords_k(k_split.block_begin_item_k()),
block_end_item_k(k_split.block_end_item_k(dim_k)),
block_warp_item_coords(warp_item_coords()),
loader_a(
reinterpret_cast<value_a_t const *>(d_a),
(IsLayoutCongruousA ? dim_m : block_end_item_k),
(IsLayoutCongruousA ? 0 : block_item_coords_k),
(IsLayoutCongruousA ? block_end_item_k : dim_m),
(IsLayoutCongruousA ? dim_m : dim_k),
(IsLayoutCongruousA ? block_item_coords_k : 0),
a_block_item_coords()),
loader_b(
reinterpret_cast<value_b_t const *>(d_b),
(IsLayoutCongruousB ? dim_n : block_end_item_k),
(IsLayoutCongruousB ? 0 : block_item_coords_k),
(IsLayoutCongruousB ? block_end_item_k : dim_n),
(IsLayoutCongruousB ? dim_n : dim_k),
(IsLayoutCongruousB ? block_item_coords_k : 0),
b_block_item_coords())
{}
//-------------------------------------------------------------------------
// Prefetching utility methods
//-------------------------------------------------------------------------
/**
* Request the calling thread's slices of the shared tiles at depth \p tile_offset_k
*/
inline __device__ void request_local_prefetch(
fragment_a_t local_slices_a[WmmaBlocksY], ///< Slice from A
fragment_b_t local_slices_b[WmmaBlocksX], ///< Slice from B
int tile_offset_k)
{
value_b_t const *smem_A_base = &scratch->pages[page_idx].alias().block_a[0][0];
value_b_t const *smem_B_base = &scratch->pages[page_idx].alias().block_b[0][0];
int constexpr kstride_a = (IsLayoutCongruousA ? LdmSmemA : 1);
int constexpr lstride_a = (IsLayoutCongruousA ? 1 : LdmSmemA);
int constexpr kstride_b = (IsLayoutCongruousB ? LdmSmemB : 1);
int constexpr lstride_b = (IsLayoutCongruousB ? 1 : LdmSmemB);
// Load B strip
#pragma unroll
for (int i = 0; i < WmmaBlocksX; ++i)
{
value_b_t const *smem_B_ptr =
&smem_B_base[tile_offset_k * kstride_b + (block_warp_item_coords.x + WmmaItemsX * i) * lstride_b];
nvcuda::wmma::load_matrix_sync(local_slices_b[i], smem_B_ptr, LdmSmemB);
}
// Load A strip
#pragma unroll
for (int i = 0; i < WmmaBlocksY; ++i)
{
value_a_t const *smem_A_ptr =
&smem_A_base[tile_offset_k * kstride_a + (block_warp_item_coords.y + WmmaItemsY * i) * lstride_a];
nvcuda::wmma::load_matrix_sync(local_slices_a[i], smem_A_ptr, LdmSmemA);
}
}
//-------------------------------------------------------------------------
// Epilogue
//-------------------------------------------------------------------------
/**
* Performs the GEMM epilogue:
* - Applies the scalar multipliers and addends to the accumulators
* - Write the result to the output matrix
*/
inline __device__ void epilogue()
{
// Wait for predecessor thread block(s) to produce partial-sums
k_split.wait();
// Configure epilogue as to whether the thread block is a secondary
// accumulator in an inter-block k-splitting scheme
if (k_split.is_secondary_accumulator())
epilogue_op.set_secondary_accumulator();
// Whether or not the addend from C needs loading
bool must_init_addend = epilogue_op.must_init_addend();
int warp_base_x = grid_raster.block_item_coords.x + block_warp_item_coords.x;
int warp_base_y = grid_raster.block_item_coords.y + block_warp_item_coords.y;
int constexpr SmemStride = WmmaItemsY;
int warp_id = threadIdx.x / 32;
// Compute shape of one accumulator read/modify/write operation
int constexpr ItemsY = (WmmaItemsY);
int constexpr ItemsX = (32 / ItemsY);
int constexpr IterationsX = WmmaItemsX / ItemsX;
// Compute a rasterization of warp lanes across the WMMA tile.
int lane_id = (threadIdx.x % 32);
int lane_read_x = (lane_id / ItemsY);
int lane_read_y = (lane_id % ItemsY);
accum_t *smem_scratch = scratch->epilogue[warp_id];
accum_t const *smem_read_ptr = smem_scratch + lane_read_y + lane_read_x * SmemStride;
#pragma unroll
for (int xb = 0; xb < WmmaBlocksX; ++xb)
{
#pragma unroll
for (int yb = 0; yb < WmmaBlocksY; ++yb)
{
// Store accumulator tile to SMEM
nvcuda::wmma::store_matrix_sync(
smem_scratch,
accumulator.accumulators[xb][yb],
SmemStride,
matrix_layout<matrix_transform_t::NonTranspose>::kind);
// Synchronize threads within the warp
__syncthreads();
// Compute lane coordinates so that each thread efficiently accesses SMEM.
int c_x = (warp_base_x + (xb) * WmmaItemsX + lane_read_x);
int c_y = (warp_base_y + (yb) * WmmaItemsY + lane_read_y);
// Compute guard predicate by comparing against problem dimensions.
bool pred = c_y < dim_m;
// Compute output pointer from lane coordinates
int c_index = c_x * dim_m + c_y;
accum_t *c_ptr = reinterpret_cast<accum_t *>(d_c) + c_x * dim_m + c_y;
// Iterate over columns of output tile. Load from SMEM, compute epilogue operation,
// and stream output to global memory
#pragma unroll
for (int item_x = 0; item_x < IterationsX; ++item_x)
{
accum_t accum = smem_read_ptr[item_x * ItemsX * SmemStride];
accum_t c_element = 0;
// Filter against problem dimensions as the warp iterates across the columns of
// output.
pred = (pred && ((c_x + item_x * ItemsX) < dim_n));
if (must_init_addend && pred)
{
// NB: inline PTX to utilize strong operations for inter-block synchronization.
// The following is equivalent to:
//
// c_element = c_ptr[0];
asm volatile ("ld.global.cg.f32 %0, [%1];\n" : "=f"(c_element) : "l"(c_ptr));
}
c_element = epilogue_op(accum, c_element, c_index);
if (pred)
{
// NB: inline PTX to utilize strong operations for inter-block synchronization.
// The following is equivalent to:
//
// c_ptr[0] = c_element;
asm volatile ("st.global.cg.f32 [%0], %1;\n" : : "l"(c_ptr), "f"(c_element));
}
// Increment output pointer
c_ptr += dim_m * ItemsX;
c_index += dim_m * ItemsX;
}
__syncthreads();
}
}
// Signal k-split successor thread_block
k_split.signal();
}
//-------------------------------------------------------------------------
// Tile consumption
//-------------------------------------------------------------------------
/**
* Consume a tile of A and B each
*/
template <bool DoGlobalPrefetch>
inline __device__
void consume_tile()
{
// Request global prefetch for next tile on first strip
if (DoGlobalPrefetch)
{
loader_b.request();
loader_b.next();
loader_a.request();
loader_a.next();
}
// Unroll BlockDpVectorsK iterations of outer-product accumulations
#pragma unroll
for (int iteration = 0; iteration < WmmaUnrollCount; ++iteration)
{
int tile_offset_k = iteration * WmmaItemsK;
// Active load-from-shared index
int active_lds_idx = __NV_STD_MIN(WmmaUnrollCount - 1, (iteration) % 2);
// Next load-from-shared index
int next_lds_idx = __NV_STD_MIN(WmmaUnrollCount - 1, (iteration + 1) % 2);
// The last unrolled iteration commits the global fetches
if ((iteration == WmmaUnrollCount - 1) && DoGlobalPrefetch)
{
// If not using two pages of scratch tiles, protect the above prefetch loads from
// the committing writes below
if (!UseDoubleScratchTiles)
{
__syncthreads();
}
else
{
page_idx = (page_idx ? 0 : 1);
}
// Commit global prefetch data to scratch page
loader_a.template commit<LdmSmemA>(&scratch->pages[page_idx].alias().block_a[0][0]);
loader_b.template commit<LdmSmemB>(&scratch->pages[page_idx].alias().block_b[0][0]);
__syncthreads();
}
// Accumulate this dp-stripe product
accumulator.multiply_accumulate(
local_slices_a[active_lds_idx],
local_slices_b[active_lds_idx]);
// Request local prefetch for next strip
request_local_prefetch(
local_slices_a[next_lds_idx],
local_slices_b[next_lds_idx],
(tile_offset_k + WmmaItemsK) % BlockItemsK);
}
}
//-------------------------------------------------------------------------
// GEMM API
//-------------------------------------------------------------------------
/**
* Compute GEMM
*/
inline __device__
void run()
{
// Quit if the thread block is fully out-of-bounds
if (grid_raster.is_block_oob(dim_m, dim_n))
{
asm volatile("exit;");
}
// Request global prefetch of first tile
loader_a.request();
loader_a.next();
loader_b.request();
loader_b.next();
// Commit global prefetch of first tile to shared memory
loader_a.template commit<LdmSmemA>(&scratch->pages[page_idx].alias().block_a[0][0]);
loader_b.template commit<LdmSmemB>(&scratch->pages[page_idx].alias().block_b[0][0]);
// Advance to next A,B tiles in K-axis
block_item_coords_k += BlockItemsK;
// Synchronize shared tiles and prepared accumulator
__syncthreads();
// Initialize thread's slice of accumulators
accumulator.init();
// Request first iteration of local prefetch strips
request_local_prefetch(
local_slices_a[0],
local_slices_b[0],
0);
//
// Main loop
//
// Consume tiles in A and B along the K-axis (all but last tile)
#pragma unroll 1
while (block_item_coords_k < block_end_item_k)
{
consume_tile<true>();
// Advance to next A,B tiles in K-axis
block_item_coords_k += BlockItemsK;
}
consume_tile<false>();
//
// Eplilogue
//
// prevent overwriting SMEM until all warps have finished loading data
__syncthreads();
// store accumulator tile to global memory
epilogue();
}
};
} // namespace gemm
} // namespace cutlass
#endif

534
cutlass/gemm/dispatch.h Normal file
View File

@@ -0,0 +1,534 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* GEMM kernel entrypoint and dispatch stub
*/
#include <stdint.h>
#include "../util/util.h"
#include "block_task.h"
#include "block_task_wmma.h"
#include "grid_raster.h"
#include "dispatch_policies.h"
#include "k_split_control.h"
namespace cutlass {
namespace gemm {
/******************************************************************************
* param_pack
******************************************************************************/
/**
* Parameter-pack structure
*
* Kernel launch latency is reduced when kernel arguments are wrapped into
* a single parameter
*/
template <
typename value_t,
typename accum_t,
typename epilogue_op_t>
struct param_pack
{
int m; ///< Height in rows of op(A) and C
int n; ///< Width in columns of op(B) and C
int k; ///< Width in columns of op(A) and height in rows of op(B)
k_split_control k_split; ///< Abstraction for controlling inter-block k-splitting
value_t *d_a; ///< Pointer to matrix A array values
value_t *d_b; ///< Pointer to matrix B array values
accum_t *d_c; ///< Pointer to matrix C array values
epilogue_op_t epilogue_op;
param_pack(
int m, ///< Height in rows of op(A) and C
int n, ///< Width in columns of op(B) and C
int k, ///< Width in columns of op(A) and height in rows of op(B)
k_split_control k_split, ///< Abstraction for controlling inter-block k-splitting
epilogue_op_t op, ///< Epilogue operation to update matrix C
value_t *d_a, ///< Pointer to matrix A array values
value_t *d_b, ///< Pointer to matrix B array values
accum_t *d_c) ///< Pointer to matrix C array values
:
m(m),
n(n),
k(k),
k_split(k_split),
epilogue_op(op),
d_a(d_a),
d_b(d_b),
d_c(d_c)
{}
};
/******************************************************************************
* Conditionally select the appropriate GEMM threadblock task
******************************************************************************/
/// Conditional selection for block task
template <
math_operation_class_t math_op, ///<
typename block_task_policy_t, ///< Parameterization of block_task_policy
typename value_t, ///< Multiplicand value type (matrices A and B)
typename accum_t, ///< Accumulator value type (matrix C and scalars)
matrix_transform_t::kind_t TransformA, ///< View transform enumerant for matrix A
int LdgAlignA, ///< Alignment (in bytes) for A operand
matrix_transform_t::kind_t TransformB, ///< View transform enumerant for matrix B
int LdgAlignB, ///< Alignment (in bytes) for B operand
typename epilogue_op_t, ///< Epilogue operation applied to GEMM
int LdgAlignC, ///< Alignment (in bytes) for C operand
bool AllowRaggedTiles ///< Whether GEMM supports matrix sizes other than multiple of BlockItems{XY}
>
struct gemm_block_task;
/// Scalar math operations
template <
typename block_task_policy_t, ///< Parameterization of block_task_policy
typename value_t, ///< Multiplicand value type (matrices A and B)
typename accum_t, ///< Accumulator value type (matrix C and scalars)
matrix_transform_t::kind_t TransformA, ///< View transform enumerant for matrix A
int LdgAlignA, ///< Alignment (in bytes) for A operand
matrix_transform_t::kind_t TransformB, ///< View transform enumerant for matrix B
int LdgAlignB, ///< Alignment (in bytes) for B operand
typename epilogue_op_t, ///< Epilogue operation applied to GEMM
int LdgAlignC, ///< Alignment (in bytes) for C operand
bool AllowRaggedTiles ///< Whether GEMM supports matrix sizes other than multiple of BlockItems{XY}
>
struct gemm_block_task<
math_operation_class_t::scalar,
block_task_policy_t,
value_t,
accum_t,
TransformA,
LdgAlignA,
TransformB,
LdgAlignB,
epilogue_op_t,
LdgAlignC,
AllowRaggedTiles
>
{
// Parameterize task type
typedef block_task<
block_task_policy_t,
value_t,
accum_t,
TransformA,
LdgAlignA,
TransformB,
LdgAlignB,
epilogue_op_t,
LdgAlignC,
AllowRaggedTiles> type;
};
/// Matrix math operations
template <
typename block_task_policy_t, ///< Parameterization of block_task_policy
typename value_t, ///< Multiplicand value type (matrices A and B)
typename accum_t, ///< Accumulator value type (matrix C and scalars)
matrix_transform_t::kind_t TransformA, ///< View transform enumerant for matrix A
int LdgAlignA, ///< Alignment (in bytes) for A operand
matrix_transform_t::kind_t TransformB, ///< View transform enumerant for matrix B
int LdgAlignB, ///< Alignment (in bytes) for B operand
typename epilogue_op_t, ///< Epilogue operation applied to GEMM
int LdgAlignC, ///< Alignment (in bytes) for C operand
bool AllowRaggedTiles ///< Whether GEMM supports matrix sizes other than multiple of BlockItems{XY}
>
struct gemm_block_task<
math_operation_class_t::matrix,
block_task_policy_t,
value_t,
accum_t,
TransformA,
LdgAlignA,
TransformB,
LdgAlignB,
epilogue_op_t,
LdgAlignC,
AllowRaggedTiles>
{
#if defined(WMMA) // conditional compilation with WMMA headers
// Parameterize task type
typedef block_task_wmma<
block_task_policy_t,
value_t,
accum_t,
TransformA,
LdgAlignA,
TransformB,
LdgAlignB,
epilogue_op_t,
LdgAlignC,
AllowRaggedTiles> type;
#endif
};
/******************************************************************************
* GEMM kernel entrypoint
******************************************************************************/
/**
* GEMM kernel
*
* NB: Not sure why NVVM is doing stuff with "__launch_bounds__" instead of just
* passing it along to PTXAS, but it is currently resulting in less optimal codegen
*/
template <
math_operation_class_t math_op, ///< Indicates which class of math operation to select
typename block_task_policy_t, ///< Parameterization of block_task_policy
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
int LdgAlignA, ///< Alignment of A matrix elements in bytes
matrix_transform_t::kind_t TransformB, ///< Transformation op for matrix B
int LdgAlignB, ///< Alignment of B matrix elements in bytes
typename value_t, ///< Multiplicand value type (matrices A and B)
typename accum_t, ///< Accumulator value type (matrix C and scalars)
typename epilogue_op_t, ///< Epilogue operation applied to update matrix C
int LdgAlignC, ///< Alignment of C elements in bytes
bool AllowRaggedTiles> ///< Boolean to indicate whether AllowRaggedTiles handling is enabled
__global__ void kernel(param_pack<value_t, accum_t, epilogue_op_t> pack)
{
// Parameterize task type
typedef typename gemm_block_task<
math_op,
block_task_policy_t,
value_t,
accum_t,
TransformA,
LdgAlignA,
TransformB,
LdgAlignB,
epilogue_op_t,
LdgAlignC,
AllowRaggedTiles>::type block_task_t;
// Declare statically-allocated shared storage
__shared__ typename block_task_t::scratch_storage_t smem;
// Construct and run the task
block_task_t(
&smem,
pack.d_a,
pack.d_b,
pack.d_c,
pack.epilogue_op,
pack.m,
pack.n,
pack.k,
pack.k_split).run();
}
/******************************************************************************
* Launch configuration description returned to the caller
******************************************************************************/
/// Return details about the launch configuration to the caller
struct launch_configuration
{
//
// Data members
//
/// cudaError_t resulting from grid launch
cudaError_t result;
/// Extent of a thread block's partition along the GEMM K-axis
int split_k;
/// Kernel grid extents in thread blocks
dim3 grid;
/// Thread block extents in threads
dim3 block;
//
// Methods
//
/// Constructor
launch_configuration():
result(cudaSuccess),
split_k(0),
grid(0, 0, 0),
block(0, 0, 0) {
}
/// Conversion from cudaError_t
launch_configuration(cudaError_t result):
result(result),
split_k(1),
grid(0, 0, 0),
block(0, 0, 0) {
}
/// Launch configuration for Cutlass kernels
launch_configuration(
cudaError_t result,
int split_k,
dim3 grid,
dim3 block
):
result(result),
split_k(split_k),
grid(grid),
block(block) {
}
};
/******************************************************************************
* Dispatch stub
******************************************************************************/
/**
* GEMM dispatch stub
*
* This function also serves as the autotuning entrypoint to evaluate different
* tuning parameterizations of kernel.
*/
template <
math_operation_class_t math_op, ///< Indicates which class of math operation to select
typename block_task_policy_t, ///< Parameterization of block_task_policy
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
int LdgAlignA, ///< Alignment of A matrix elements in bytes
matrix_transform_t::kind_t TransformB, ///< Transformation op for matrix B
int LdgAlignB, ///< Alignment of B matrix elements in bytes
typename value_t, ///< Multiplicand value type (matrices A and B)
typename accum_t, ///< Accumulator value type (matrix C and scalars)
typename epilogue_op_t, ///< Epilogue operation
int LdgAlignC, ///< Alignment of C matrix elements in bytes
bool AllowRaggedTiles, ///< Boolean to indicate whether AllowRaggedTiles handling is enabled
typename kernel_ptr_t> ///< GEMM kernel function pointer type
launch_configuration dispatch(
kernel_ptr_t kernel_ptr, ///< GEMM kernel function pointer
int m, ///< Height in rows of op(A) and C
int n, ///< Width in columns of op(B) and C
int k, ///< Width in columns of op(A) and height in rows of op(B)
epilogue_op_t epilogue_op, ///< Epilogue operation to update matrix C
value_t *d_a, ///< Device pointer to matrix A array values
value_t *d_b, ///< Device pointer to matrix B array values
accum_t *d_c, ///< Device pointer to matrix C array values
cudaStream_t stream = 0, ///< CUDA stream to launch kernels within. Default is stream<sub>0</sub>.
bool debug_synchronous = true) ///< Whether or not to synchronize the stream after every kernel launch
/// to check for errors. Also causes launch configurations to be printed
/// to the console if DEBUG is defined. Default is \p false.
{
// Thread block rasterization type
typedef grid_raster<
block_task_policy_t::BlockItemsY,
block_task_policy_t::BlockItemsX,
TransformA,
TransformB,
block_task_policy_t::RasterStrategy>
grid_raster_t;
launch_configuration config;
// Compute block dims
config.block = dim3(block_task_policy_t::BlockThreads);
// Compute shared memory
int dynamic_smem_bytes = 0;
// Compute occupancy
int max_sm_occupancy;
if (CUDA_PERROR_DEBUG(config.result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&max_sm_occupancy,
kernel_ptr,
config.block.x * config.block.y,
dynamic_smem_bytes)))
{
return config;
}
// Compute grid extents
config.grid = grid_raster_t::grid_dims(m, n);
// Get SM count
int sm_count;
if (CUDA_PERROR_DEBUG(config.result = get_sm_count(sm_count)))
return config;
// Get k-split flag storage (TODO: make a pool)
int *d_flags;
if (CUDA_PERROR_DEBUG(config.result = cudaGetSymbolAddress((void**) &d_flags, d_flags_split_k)))
return config;
// Construct k-split coordinator
k_split_control k_split(
d_flags,
sm_count,
max_sm_occupancy,
k,
block_task_policy_t::BlockItemsK,
config.block,
config.grid); // in,out
config.split_k = k_split.split_k;
// Log kernel configuration
if (debug_synchronous)
{
// Compute tiling efficiency
float block_tiling_efficiency = float(block_task_policy_t::BlockItemsY * block_task_policy_t::BlockItemsX) /
float(block_task_policy_t::BlockItemsY + block_task_policy_t::BlockItemsX);
float tiling_efficiency = block_tiling_efficiency;
float wave_efficiency = k_split.get_wave_efficiency(
sm_count, max_sm_occupancy, config.block, config.grid);
CUDA_LOG_DEBUG("Final wave_efficiency %.4f, tiling_efficiency %.4f\n",
wave_efficiency, tiling_efficiency);
CUDA_LOG_DEBUG("Invoking kernel<<<(%d, %d, %d), (%d.y,%d.x), %d, %lld>>>(), %d SM occupancy, %d split_k\n",
config.grid.x, config.grid.y, config.grid.z,
config.block.y, config.block.x,
dynamic_smem_bytes,
(long long) stream,
max_sm_occupancy,
k_split.split_k);
}
// Construct parameter-pack
param_pack<value_t, accum_t, epilogue_op_t> pack(
m,
n,
k,
k_split,
epilogue_op,
d_a,
d_b,
d_c);
// Prepare k-split coordinator
if (CUDA_PERROR_DEBUG(config.result = k_split.prepare(stream, debug_synchronous)))
{
return config;
}
// Invoke kernel
kernel_ptr<<< config.grid, config.block, dynamic_smem_bytes, stream >>>(pack);
// Check for failure to launch
if (CUDA_PERROR_DEBUG(config.result = cudaPeekAtLastError()))
return config;
// Sync the stream if specified to flush runtime errors
if (debug_synchronous && (CUDA_PERROR_DEBUG(config.result = cudaStreamSynchronize(stream))))
return config;
return config;
}
/******************************************************************************
* GEMM
******************************************************************************/
/**
* Computes gemm on device matrices
*/
template <
tiling_strategy::kind_t TilingStrategy, ///< Tile-sizing classification
math_operation_class_t math_op, ///< Indicates which class of math operation to select
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
int LdgAlignA, ///< Alignment (in bytes) of A operand
matrix_transform_t::kind_t TransformB, ///< Transformation op for matrix B
int LdgAlignB, ///< Alignment (in bytes) of B operand
typename value_t, ///< Multiplicand value type (matrices A and B)
typename accum_t, ///< Accumulator value type (matrix C and scalars)
typename epilogue_op_t, ///< Epilogue operation to update matrix C
int LdgAlignC> ///< Alignment (in bytes) of C operand
launch_configuration device_gemm(
int m, ///< Height in rows of op(A) and C
int n, ///< Width in columns of op(B) and C
int k, ///< Width in columns of op(A) and height in rows of op(B)
epilogue_op_t epilogue_op, ///< Epilogue operation to update matrix C
value_t *d_a, ///< Device pointer to matrix A array values
value_t *d_b, ///< Device pointer to matrix B array values
accum_t *d_c, ///< Device pointer to matrix C array values
cudaStream_t stream = 0, ///< CUDA stream to launch kernels within. Default is stream<sub>0</sub>.
bool debug_synchronous = false) ///< Whether or not to synchronize the stream after every kernel launch to
/// check for errors. Also causes launch configurations to be printed to
/// the console if DEBUG is defined. Default is \p false.
{
// Parameterize an task policy type
// (TODO: use a policy dispatch mechanism based upon SM version)
typedef gemm_policy<value_t, accum_t, TransformA, TransformB, TilingStrategy> block_task_policy_t;
// AllowRaggedTiles-tile check
if ((m % block_task_policy_t::BlockItemsY != 0) ||
(n % block_task_policy_t::BlockItemsX != 0) ||
(k % block_task_policy_t::BlockItemsK != 0))
{
// Needs ragged tile-handling
static const bool AllowRaggedTiles = true;
return dispatch<math_op, block_task_policy_t, TransformA, LdgAlignA, TransformB, LdgAlignB, value_t, accum_t, epilogue_op_t, LdgAlignC, AllowRaggedTiles>(
kernel<math_op,block_task_policy_t, TransformA, LdgAlignA, TransformB, LdgAlignB, value_t, accum_t, epilogue_op_t, LdgAlignC, AllowRaggedTiles>,
m,
n,
k,
epilogue_op,
d_a,
d_b,
d_c,
stream,
debug_synchronous);
}
else
{
// Does not need ragged tile-handling
static const bool AllowRaggedTiles = false;
return dispatch<math_op, block_task_policy_t, TransformA, LdgAlignA, TransformB, LdgAlignB, value_t, accum_t, epilogue_op_t, LdgAlignC, AllowRaggedTiles>(
kernel<math_op,block_task_policy_t, TransformA, LdgAlignA, TransformB, LdgAlignB, value_t, accum_t, epilogue_op_t, LdgAlignC, AllowRaggedTiles>,
m,
n,
k,
epilogue_op,
d_a,
d_b,
d_c,
stream,
debug_synchronous);
}
}
} // namespace gemm
} // namespace cutlass

View File

@@ -0,0 +1,653 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* Architecture-specific GEMM block_task policies
*/
#include <stdint.h>
#include "../util/util.h"
#include "block_task.h"
#include "grid_raster.h"
namespace cutlass {
namespace gemm {
/******************************************************************************
* tiling_strategy
******************************************************************************/
/**
* Enumeration of tile-sizing granularities
*/
struct tiling_strategy : printable_t
{
/// \brief Enumerants
enum kind_t
{
Unknown,
Small,
Medium,
Large,
Tall,
Wide,
Huge,
};
/// Enumerant value
kind_t kind;
/// Default constructor
tiling_strategy() : kind(Unknown) {}
/// Copy constructor
tiling_strategy(const kind_t &other_kind) : kind(other_kind) {}
/// Cast to kind_t
operator kind_t() const { return kind; }
/// Returns the instance as a string
__host__ __device__ inline
char const* to_string() const
{
switch (kind)
{
case Small: return "small";
case Medium: return "medium";
case Large: return "large";
case Tall: return "tall";
case Wide: return "wide";
case Huge: return "huge";
case Unknown:
default: return "unknown";
}
}
/// Insert the formatted instance into the output stream
void print(std::ostream& out) const { out << to_string(); }
};
/******************************************************************************
* GEMM
******************************************************************************/
/**
* GEMM task policy specialization for sgemm
*/
template <
typename value_t,
typename accum_t,
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB, ///< Transformation op for matrix B
tiling_strategy::kind_t TilingStrategy> ///< Tile-sizing classification
struct gemm_policy;
/******************************************************************************
* SGEMM
******************************************************************************/
/**
* GEMM task policy specialization for small sgemm
*/
template <
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
struct gemm_policy<float, float, TransformA, TransformB, tiling_strategy::Small> :
block_task_policy<
16, // _BlockItemsY
16, // _BlockItemsX
16, // _BlockItemsK
2, // _ThreadItemsY
2, // _ThreadItemsX
false, // _UseDoubleScratchTiles
grid_raster_strategy::Default> // _RasterStrategy
{};
/**
* GEMM task policy specialization for medium sgemm
*/
template <
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
struct gemm_policy<float, float, TransformA, TransformB, tiling_strategy::Medium> :
block_task_policy<
32, // _BlockItemsY
32, // _BlockItemsX
8, // _BlockItemsK
4, // _ThreadItemsY
4, // _ThreadItemsX
false, // _UseDoubleScratchTiles
grid_raster_strategy::Default> // _RasterStrategy
{};
/**
* GEMM task policy specialization for large sgemm
*/
template <
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
struct gemm_policy<float, float, TransformA, TransformB, tiling_strategy::Large> :
block_task_policy<
64, // _BlockItemsY
64, // _BlockItemsX
8, // _BlockItemsK
8, // _ThreadItemsY
8, // _ThreadItemsX
false, // _UseDoubleScratchTiles
grid_raster_strategy::Default> // _RasterStrategy
{};
/**
* GEMM task policy specialization for tall sgemm
*/
template <
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
struct gemm_policy<float, float, TransformA, TransformB, tiling_strategy::Tall> :
block_task_policy<
128, // _BlockItemsY
32, // _BlockItemsX
8, // _BlockItemsK
8, // _ThreadItemsY
4, // _ThreadItemsX
false, // _UseDoubleScratchTiles
grid_raster_strategy::Default> // _RasterStrategy
{};
/**
* GEMM task policy specialization for wide sgemm
*/
template <
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
struct gemm_policy<float, float, TransformA, TransformB, tiling_strategy::Wide> :
block_task_policy<
32, // _BlockItemsY
128, // _BlockItemsX
8, // _BlockItemsK
4, // _ThreadItemsY
8, // _ThreadItemsX
false, // _UseDoubleScratchTiles
grid_raster_strategy::Default> // _RasterStrategy
{};
/**
* GEMM task policy specialization for huge sgemm
*/
template <
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
struct gemm_policy<float, float, TransformA, TransformB, tiling_strategy::Huge> :
block_task_policy<
128, // _BlockItemsY
128, // _BlockItemsX
8, // _BlockItemsK
8, // _ThreadItemsY
8, // _ThreadItemsX
false, // _UseDoubleScratchTiles
grid_raster_strategy::Default> // _RasterStrategy
{};
/******************************************************************************
* DGEMM
******************************************************************************/
/**
* GEMM task policy specialization for small dgemm
*/
template <
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
struct gemm_policy<double, double, TransformA, TransformB, tiling_strategy::Small> :
block_task_policy<
16, // _BlockItemsY
16, // _BlockItemsX
16, // _BlockItemsK
2, // _ThreadItemsY
2, // _ThreadItemsX
false, // _UseDoubleScratchTiles
grid_raster_strategy::Default> // _RasterStrategy
{};
/**
* GEMM task policy specialization for medium dgemm
*/
template <
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
struct gemm_policy<double, double, TransformA, TransformB, tiling_strategy::Medium> :
block_task_policy<
32, // _BlockItemsY
32, // _BlockItemsX
16, // _BlockItemsK
4, // _ThreadItemsY
4, // _ThreadItemsX
false, // _UseDoubleScratchTiles
grid_raster_strategy::Default> // _RasterStrategy
{};
/**
* GEMM task policy specialization for large dgemm
*/
template <
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
struct gemm_policy<double, double, TransformA, TransformB, tiling_strategy::Large> :
block_task_policy<
64, // _BlockItemsY
64, // _BlockItemsX
8, // _BlockItemsK
4, // _ThreadItemsY
4, // _ThreadItemsX
false, // _UseDoubleScratchTiles
grid_raster_strategy::Default> // _RasterStrategy
{};
/**
* GEMM task policy specialization for tall dgemm
*/
template <
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
struct gemm_policy<double, double, TransformA, TransformB, tiling_strategy::Tall> :
block_task_policy<
128, // _BlockItemsY
32, // _BlockItemsX
8, // _BlockItemsK
8, // _ThreadItemsY
4, // _ThreadItemsX
false, // _UseDoubleScratchTiles
grid_raster_strategy::Default> // _RasterStrategy
{};
/**
* GEMM task policy specialization for wide dgemm
*/
template <
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
struct gemm_policy<double, double, TransformA, TransformB, tiling_strategy::Wide> :
block_task_policy<
32, // _BlockItemsY
128, // _BlockItemsX
8, // _BlockItemsK
4, // _ThreadItemsY
8, // _ThreadItemsX
false, // _UseDoubleScratchTiles
grid_raster_strategy::Default> // _RasterStrategy
{};
/**
* GEMM task policy specialization for huge dgemm
*/
template <
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
struct gemm_policy<double, double, TransformA, TransformB, tiling_strategy::Huge> :
block_task_policy<
64, // _BlockItemsY
128, // _BlockItemsX
8, // _BlockItemsK
8, // _ThreadItemsY
8, // _ThreadItemsX
false, // _UseDoubleScratchTiles
grid_raster_strategy::Default> // _RasterStrategy
{};
/******************************************************************************
* HGEMM
******************************************************************************/
/**
* GEMM task policy specialization for small hgemm
*/
template <
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
struct gemm_policy<__half, __half, TransformA, TransformB, tiling_strategy::Small> :
block_task_policy<
32, // _BlockItemsY
32, // _BlockItemsX
8, // _BlockItemsK
4, // _ThreadItemsY
4, // _ThreadItemsX
false, // _UseDoubleScratchTiles
grid_raster_strategy::Default> // _RasterStrategy
{};
/**
* GEMM task policy specialization for medium hgemm
*/
template <
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
struct gemm_policy<__half, __half, TransformA, TransformB, tiling_strategy::Medium> :
block_task_policy<
32, // _BlockItemsY
32, // _BlockItemsX
16, // _BlockItemsK
8, // _ThreadItemsY
4, // _ThreadItemsX
false, // _UseDoubleScratchTiles
grid_raster_strategy::Default> // _RasterStrategy
{};
/**
* GEMM task policy specialization for large hgemm
*/
template <
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
struct gemm_policy<__half, __half, TransformA, TransformB, tiling_strategy::Large> :
block_task_policy<
64, // _BlockItemsY
64, // _BlockItemsX
8, // _BlockItemsK
16, // _ThreadItemsY
8, // _ThreadItemsX
false, // _UseDoubleScratchTiles
grid_raster_strategy::Default> // _RasterStrategy
{};
/**
* GEMM task policy specialization for tall hgemm
*/
template <
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
struct gemm_policy<__half, __half, TransformA, TransformB, tiling_strategy::Tall> :
block_task_policy<
128, // _BlockItemsY
32, // _BlockItemsX
8, // _BlockItemsK
16, // _ThreadItemsY
4, // _ThreadItemsX
false, // _UseDoubleScratchTiles
grid_raster_strategy::Default> // _RasterStrategy
{};
/**
* GEMM task policy specialization for wide hgemm
*/
template <
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
struct gemm_policy<__half, __half, TransformA, TransformB, tiling_strategy::Wide> :
block_task_policy<
32, // _BlockItemsY
128, // _BlockItemsX
8, // _BlockItemsK
8, // _ThreadItemsY
8, // _ThreadItemsX
false, // _UseDoubleScratchTiles
grid_raster_strategy::Default> // _RasterStrategy
{};
/**
* GEMM task policy specialization for huge hgemm
*/
template <
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
struct gemm_policy<__half, __half, TransformA, TransformB, tiling_strategy::Huge> :
block_task_policy<
128, // _BlockItemsY
128, // _BlockItemsX
8, // _BlockItemsK
16, // _ThreadItemsY
8, // _ThreadItemsX
false, // _UseDoubleScratchTiles
grid_raster_strategy::Default> // _RasterStrategy
{};
/******************************************************************************
* IGEMM
******************************************************************************/
/**
* GEMM task policy specialization for small igemm
*/
template <
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
struct gemm_policy<int8_t, int32_t, TransformA, TransformB, tiling_strategy::Small> :
block_task_policy<
16, // _BlockItemsY
32, // _BlockItemsX
32, // _BlockItemsK
4, // _ThreadItemsY
4, // _ThreadItemsX
false, // _UseDoubleScratchTiles
grid_raster_strategy::Default> // _RasterStrategy
{};
/**
* GEMM task policy specialization for medium igemm
*/
template <
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
struct gemm_policy<int8_t, int32_t, TransformA, TransformB, tiling_strategy::Medium> :
block_task_policy<
32, // _BlockItemsY
32, // _BlockItemsX
32, // _BlockItemsK
4, // _ThreadItemsY
4, // _ThreadItemsX
false, // _UseDoubleScratchTiles
grid_raster_strategy::Default> // _RasterStrategy
{};
/**
* GEMM task policy specialization for large igemm
*/
template <
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
struct gemm_policy<int8_t, int32_t, TransformA, TransformB, tiling_strategy::Large> :
block_task_policy<
64, // _BlockItemsY
64, // _BlockItemsX
32, // _BlockItemsK
8, // _ThreadItemsY
4, // _ThreadItemsX
false, // _UseDoubleScratchTiles
grid_raster_strategy::Default> // _RasterStrategy
{};
/**
* GEMM task policy specialization for large igemm
*/
template <
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
struct gemm_policy<int8_t, int32_t, TransformA, TransformB, tiling_strategy::Tall> :
block_task_policy<
128, // _BlockItemsY
64, // _BlockItemsX
64, // _BlockItemsK
8, // _ThreadItemsY
4, // _ThreadItemsX
false, // _UseDoubleScratchTiles
grid_raster_strategy::Default> // _RasterStrategy
{};
/**
* GEMM task policy specialization for large igemm
*/
template <
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
struct gemm_policy<int8_t, int32_t, TransformA, TransformB, tiling_strategy::Wide> :
block_task_policy<
64, // _BlockItemsY
128, // _BlockItemsX
64, // _BlockItemsK
4, // _ThreadItemsY
8, // _ThreadItemsX
false, // _UseDoubleScratchTiles
grid_raster_strategy::Default> // _RasterStrategy
{};
/**
* GEMM task policy specialization for huge igemm
*/
template <
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
struct gemm_policy<int8_t, int32_t, TransformA, TransformB, tiling_strategy::Huge> :
block_task_policy<
128, // _BlockItemsY
128, // _BlockItemsX
32, // _BlockItemsK
8, // _ThreadItemsY
8, // _ThreadItemsX
false, // _UseDoubleScratchTiles
grid_raster_strategy::Default> // _RasterStrategy
{};
/******************************************************************************
* WMMA GEMM
******************************************************************************/
// WMMA is a preview feature in CUDA. Conditionally enable wmma_gemm policies.
#if defined(WMMA)
template <
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
struct gemm_policy<half, float, TransformA, TransformB, tiling_strategy::Small> :
gemm::block_task_wmma_policy<
16, // _BlockItemsY
16, // _BlockItemsX
16, // _BlockItemsK
16, // _WarpItemsY
16, // _WarpItemsX
16, // _WmmaItemsY
16, // _WmmaItemsX
16, // _WmmaItemsK
false, // _UseDoubleScratchTiles
gemm::grid_raster_strategy::Default> // _RasterStrategy
{};
template <
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
struct gemm_policy<half, float, TransformA, TransformB, tiling_strategy::Medium> :
gemm::block_task_wmma_policy<
32, // _BlockItemsY
32, // _BlockItemsX
32, // _BlockItemsK
32, // _WarpItemsY
32, // _WarpItemsX
16, // _WmmaItemsY
16, // _WmmaItemsX
16, // _WmmaItemsK
false, // _UseDoubleScratchTiles
gemm::grid_raster_strategy::Default> // _RasterStrategy
{};
template <
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
struct gemm_policy< half, float, TransformA, TransformB, tiling_strategy::Large> :
gemm::block_task_wmma_policy<
64, // _BlockItemsY
64, // _BlockItemsX
32, // _BlockItemsK
32, // _WarpItemsY
64, // _WarpItemsX
16, // _WmmaItemsY
16, // _WmmaItemsX
16, // _WmmaItemsK
false, // _UseDoubleScratchTiles
gemm::grid_raster_strategy::Default> // _RasterStrategy
{};
template <
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
struct gemm_policy< half, float, TransformA, TransformB, tiling_strategy::Tall> :
gemm::block_task_wmma_policy<
128, // _BlockItemsY
64, // _BlockItemsX
64, // _BlockItemsK
32, // _WarpItemsY
64, // _WarpItemsX
16, // _WmmaItemsY
16, // _WmmaItemsX
16, // _WmmaItemsK
false, // _UseDoubleScratchTiles
gemm::grid_raster_strategy::Default> // _RasterStrategy
{};
template <
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
struct gemm_policy< half, float, TransformA, TransformB, tiling_strategy::Wide> :
gemm::block_task_wmma_policy<
64, // _BlockItemsY
128, // _BlockItemsX
64, // _BlockItemsK
32, // _WarpItemsY
64, // _WarpItemsX
16, // _WmmaItemsY
16, // _WmmaItemsX
16, // _WmmaItemsK
false, // _UseDoubleScratchTiles
gemm::grid_raster_strategy::Default> // _RasterStrategy
{};
template <
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB> ///< Transformation op for matrix B
struct gemm_policy< half, float, TransformA, TransformB, tiling_strategy::Huge> :
gemm::block_task_wmma_policy<
128, // _BlockItemsY
128, // _BlockItemsX
64, // _BlockItemsK
32, // _WarpItemsY
64, // _WarpItemsX
16, // _WmmaItemsY
16, // _WmmaItemsX
16, // _WmmaItemsK
false, // _UseDoubleScratchTiles
gemm::grid_raster_strategy::Default> // _RasterStrategy
{};
#endif
} // namespace gemm
} // namespace cutlass

View File

@@ -0,0 +1,215 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* Abstraction for exposing architecture-specific "dot-product-accumulate"
* ISA operations
*/
#include <stdint.h>
#include "../util/util.h"
namespace cutlass {
namespace gemm {
/******************************************************************************
* dp_accummulate
******************************************************************************/
/**
* \brief Abstraction for exposing architecture-specific "dot-product-accumulate"
* ISA operations
*
* Given two K-component vectors a and b having type value_t[K] and an addend c
* of type accum_t, the "dot-product-accumulate" of type accum_t is computed
* as d = x[0]*y[0] + x[1]*y[1] + ... + x[K-1]*y[K-1] + c.
*
* We use the notation "dpK" to connote a K-component dot-product-accumulate.
* For example, "dp1" is a simple multiply-add.
*
* For given pairing of value_t and accum_t types, the corresponding
* dp_accummulate class will:
*
* - Define the member-type dp_vector_t as the appropriate K-component vector
* type needed to leverage architecture-specific "dot-product accumulate"
* ISA operations.
* - Implement the corresponding dot-product operation between two dp_vector_t
* inputs a and b.
*
*/
template <
typename value_t, ///< Component value type
typename accum_t> ///< Accumulator value type
struct dp_accummulate;
/// Default "dp1" dot-product-accumulate traits specialization for value_t->accum_t
template <
typename value_t, ///< Component value type
typename accum_t> ///< Accumulator value type
struct dp_accummulate
{
/// Single-component "dp1" dot-product vector type
typedef value_t dp_vector_t;
/// Compute "dp1" float->float
inline __device__
static void mad(
float &d,
const float &a,
const float &b,
const float &c)
{
asm volatile ( "fma.rn.f32 %0, %1, %2, %3;\n"
: "=f"(d) : "f"(a), "f"(b), "f"(c));
}
/// Compute "dp1" double->double
inline __device__
static void mad(
double &d,
const double &a,
const double &b,
const double &c)
{
asm volatile ("fma.rn.f64 %0, %1, %2, %3;\n"
: "=d"(d) : "d"(a), "d"(b), "d"(c));
}
/// Compute "dp1" int16_t->int32_t
inline __device__
static void mad(
int32_t &d,
const int16_t &a,
const int16_t &b,
const int32_t &c)
{
asm volatile ("mad.wide.s16 %0, %1, %2, %3;\n"
: "=r"(d) : "h"(a), "h"(b), "r"(c));
}
/// Compute "dp1" uint16_t->uint32_t
inline __device__
static void mad(
uint32_t &d,
const uint16_t &a,
const uint16_t &b,
const uint32_t &c)
{
asm volatile ("mad.wide.u16 %0, %1, %2, %3;\n"
: "=r"(d) : "h"(a), "h"(b), "r"(c));
}
/// Compute "dp1" int32_t->int32_t
inline __device__
static void mad(
int32_t &d,
const int32_t &a,
const int32_t &b,
const int32_t &c)
{
asm volatile ("mad.lo.s32 %0, %1, %2, %3;\n"
: "=r"(d) : "r"(a), "r"(b), "r"(c));
}
/// Compute "dp1" uint32_t->uint32_t
inline __device__
static void mad(
uint32_t &d,
const uint32_t &a,
const uint32_t &b,
const uint32_t &c)
{
asm volatile ("mad.lo.u32 %0, %1, %2, %3;\n"
: "=r"(d) : "r"(a), "r"(b), "r"(c));
}
};
#if (CUTLASS_ARCH >= 610) // Specializations only enabled for Pascal SM610+
/// "dp4" dot-product-accumulate traits specialization for int8_t->int32_t
template <>
struct dp_accummulate<
int8_t, ///< Component value type
int32_t> ///< Accumulator value type
{
/// Four-component signed "idp4"
typedef int32_t dp_vector_t;
/// Compute "dp4" int16_t->int32_t
inline __device__
static void mad(
int32_t &d,
const int32_t &a,
const int32_t &b,
const int32_t &c)
{
asm volatile ( "dp4a.s32.s32 %0, %1, %2, %3;\n"
: "=r"(d) : "r"(a), "r"(b), "r"(c));
}
};
/// "dp4" dot-product-accumulate traits specialization for uint8_t->uint32_t
template <>
struct dp_accummulate<
uint8_t, ///< Component value type
uint32_t> ///< Accumulator value type
{
/// Four-component unsigned "idp4"
typedef uint32_t dp_vector_t;
/// Compute "dp4" uint16_t->uint32_t
inline __device__
static void mad(
uint32_t &d,
const uint32_t &a,
const uint32_t &b,
const uint32_t &c)
{
asm volatile ( "dp4a.u32.u32 %0, %1, %2, %3;\n"
: "=r"(d) : "r"(a), "r"(b), "r"(c));
}
};
#endif // Specializations only enabled for Pascal SM610+
} // namespace gemm
} // namespace cutlass

View File

@@ -0,0 +1,96 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* Epilogue operation to compute final output
*/
namespace cutlass {
namespace gemm {
//// Used by GEMM to compute the final result C <= alpha * accumulator + beta * C
template <
typename accum_t,
typename output_t,
typename scalar_t
>
class blas_scaled_epilogue
{
public:
scalar_t alpha;
scalar_t beta;
inline __device__ __host__
blas_scaled_epilogue(
scalar_t alpha,
scalar_t beta)
:
alpha(alpha),
beta(beta)
{}
/// Epilogue operator
inline __device__ __host__
output_t operator()(
accum_t accumulator,
output_t c,
size_t idx) const
{
return output_t(alpha * scalar_t(accumulator) + beta * scalar_t(c));
}
/// Epilogue operator
inline __device__ __host__
output_t operator()(
accum_t accumulator,
size_t idx) const
{
return output_t(alpha * scalar_t(accumulator));
}
/**
* Configure epilogue as to whether the thread block is a secondary
* accumulator in an inter-block k-splitting scheme
*/
inline __device__
void set_secondary_accumulator()
{
beta = scalar_t(1);
}
/// Return whether the beta-scaled addend needs initialization
inline __device__
bool must_init_addend()
{
return (beta != scalar_t(0));
}
};
} // namespace gemm
} // namespace cutlass

428
cutlass/gemm/grid_raster.h Normal file
View File

@@ -0,0 +1,428 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* Abstraction for enumerating \p block_task within an input matrix
*/
#include <stdint.h>
#include "../util/util.h"
namespace cutlass {
namespace gemm {
/******************************************************************************
* grid_raster_strategy
******************************************************************************/
/**
* \brief Strategies for enumerating \p block_task within an input matrix
*/
struct grid_raster_strategy
{
/// \brief Enumerants
enum kind_t
{
/**
* Default \p block_task assignment (currently ColumnMajor for N*,
* RowMajor for TT, and TiledCohort for TN)
*/
Default,
/**
* Column-major \p block_task assignment
*/
ColumnMajor,
/**
* Row-major \p block_task assignment
*/
RowMajor,
/**
* Two-level \p block_task assignment (both column-major)
*/
TiledCohort,
};
};
/******************************************************************************
* grid_raster
******************************************************************************/
/**
* \brief Abstraction for enumerating \p block_task within an input matrix
*
* NB: This generic class is not directly constructible. Algorithm-specific
* template specializations will provide the API functionality prescribed here.
*/
template <
int BlockItemsY, ///< Height in rows of a block-wide tile in matrix C
int BlockItemsX, ///< Width in columns of a block-wide tile in matrix C
matrix_transform_t::kind_t TransformA, ///< View transform enumerant for matrix A
matrix_transform_t::kind_t TransformB, ///< View transform enumerant for matrix B
grid_raster_strategy::kind_t RasterStrategy> ///< Strategy for enumerating \p block_task within an input matrix
struct grid_raster
{
//-------------------------------------------------------------------------
// Device API
//-------------------------------------------------------------------------
/// Thread block's base item coordinates (x, y) in matrix C
int2 block_item_coords;
/// Constructor
grid_raster();
/// Whether the thread block base coordinates are out-of-bounds for an m*n matrix C
bool is_block_oob(int m, int n);
//-------------------------------------------------------------------------
// Grid launch API
//-------------------------------------------------------------------------
/// Compute the kernel grid extents (in thread blocks) for consuming an m*n matrix C
static dim3 grid_dims(int m, int n);
};
/******************************************************************************
* grid_raster (ColumnMajor specialization)
******************************************************************************/
/**
* \brief Abstraction for enumerating \p block_task within an input matrix
* (ColumnMajor specialization)
*
* Maps thread blocksin column-major fashion
*/
template <
int BlockItemsY, ///< Height in rows of a block-wide tile in matrix C
int BlockItemsX, ///< Width in columns of a block-wide tile in matrix C
matrix_transform_t::kind_t TransformA, ///< View transform enumerant for matrix A
matrix_transform_t::kind_t TransformB> ///< View transform enumerant for matrix B
struct grid_raster<
BlockItemsY,
BlockItemsX,
TransformA,
TransformB,
grid_raster_strategy::ColumnMajor> ///< Strategy for enumerating \p block_task within an input matrix
{
//-------------------------------------------------------------------------
// Device API
//-------------------------------------------------------------------------
/// Thread block's base item coordinates (x, y) in matrix C
int2 block_item_coords;
/// Constructor
inline __device__
grid_raster()
{
// blockDim.x is the fastest changing grid dim on current architectures
block_item_coords = make_int2(
BlockItemsX * blockIdx.y,
BlockItemsY * blockIdx.x);
}
/// Whether the base \p block_item_coords are out-of-bounds for an m*n matrix C
inline __device__
bool is_block_oob(int m, int n)
{
// ColumnMajor never rasterizes fully out-of-bounds thread blocks
return false;
}
//-------------------------------------------------------------------------
// Grid launch API
//-------------------------------------------------------------------------
/// Compute the kernel grid extents (in thread blocks) for consuming an m*n matrix C
inline __host__ __device__
static dim3 grid_dims(int m, int n)
{
// blockDim.x is the fastest changing grid dim on current architectures
return dim3(
(m + BlockItemsY - 1) / BlockItemsY,
(n + BlockItemsX - 1) / BlockItemsX);
}
};
/******************************************************************************
* grid_raster (RowMajor specialization)
******************************************************************************/
/**
* \brief Abstraction for enumerating \p block_task within an input matrix
* (RowMajor specialization)
*
* Enumerates \p block_task in row-major fashion
*/
template <
int BlockItemsY, ///< Height in rows of a block-wide tile in matrix C
int BlockItemsX, ///< Width in columns of a block-wide tile in matrix C
matrix_transform_t::kind_t TransformA, ///< View transform enumerant for matrix A
matrix_transform_t::kind_t TransformB> ///< View transform enumerant for matrix B
struct grid_raster<
BlockItemsY,
BlockItemsX,
TransformA,
TransformB,
grid_raster_strategy::RowMajor> ///< Strategy for enumerating \p block_task within an input matrix
{
//-------------------------------------------------------------------------
// Device API
//-------------------------------------------------------------------------
/// Thread block's base item coordinates (x, y) in matrix C
int2 block_item_coords;
/// Constructor
inline __device__
grid_raster()
{
// blockDim.x is the fastest changing grid dim on current architectures
block_item_coords = make_int2(
BlockItemsX * blockIdx.x,
BlockItemsY * blockIdx.y);
}
/// Whether the base \p block_item_coords are out-of-bounds for an m*n matrix C
inline __device__
bool is_block_oob(int m, int n)
{
// RowMajor never rasterizes fully out-of-bounds thread blocks
return false;
}
//-------------------------------------------------------------------------
// Grid launch API
//-------------------------------------------------------------------------
/// Compute the kernel grid extents (in thread blocks) for consuming an m*n matrix C
inline __host__ __device__
static dim3 grid_dims(int m, int n)
{
// blockDim.x is the fastest changing grid dim on current architectures
return dim3(
(n + BlockItemsX - 1) / BlockItemsX,
(m + BlockItemsY - 1) / BlockItemsY);
}
};
/******************************************************************************
* grid_raster (TiledCohort specialization)
******************************************************************************/
/**
* \brief Abstraction for enumerating \p block_task within an input matrix
* (TiledCohort specialization)
*
* Enumerates \p block_task in column-major fashion across "cohort" tiles (where
* cohorts are CohortBlocksY high and CohortBlocksX wide), and enumerates cohorts
* across the matrix in column-major fashion.
*
* Grid layout:
* - gridDim.y is the height of the grid in cohorts
* - gridDim.x is the width of the grid in cohorts multiplied by the number of
* thread blocks per cohort
*/
template <
int BlockItemsY, ///< Height in rows of a block-wide tile in matrix C
int BlockItemsX, ///< Width in columns of a block-wide tile in matrix C
matrix_transform_t::kind_t TransformA, ///< View transform enumerant for matrix A
matrix_transform_t::kind_t TransformB> ///< View transform enumerant for matrix B
struct grid_raster<
BlockItemsY,
BlockItemsX,
TransformA,
TransformB,
grid_raster_strategy::TiledCohort> ///< Strategy for enumerating \p block_task within an input matrix
{
enum
{
/// Height in thread blocks of a grid rasterization cohort
CohortBlocksY = 2,
/// Width in thread blocks of a grid rasterization cohort
CohortBlocksX = 2,
/// Number of thread blocks per cohort
BlocksPerCohort = CohortBlocksY * CohortBlocksX,
/// Height in items of a grid rasterization cohort
CohortItemsY = CohortBlocksY * BlockItemsY,
/// Width in items of a grid rasterization cohort
CohortItemsX = CohortBlocksX * BlockItemsX,
};
//-------------------------------------------------------------------------
// Device API
//-------------------------------------------------------------------------
/// Thread block's base item coordinates (x, y) in matrix C
int2 block_item_coords;
/// Constructor
inline __device__
grid_raster()
{
int block_idx_cohort = blockIdx.x % BlocksPerCohort;
int2 cohort_coords_grid = make_int2(
blockIdx.x / BlocksPerCohort,
blockIdx.y);
// Cohort is rastered in column-major order
int2 block_coords_cohort = make_int2(
block_idx_cohort / CohortBlocksY,
block_idx_cohort % CohortBlocksY);
block_item_coords = make_int2(
((cohort_coords_grid.x * CohortBlocksX) + block_coords_cohort.x) * BlockItemsX,
((cohort_coords_grid.y * CohortBlocksY) + block_coords_cohort.y) * BlockItemsY);
}
/// Whether the base \p block_item_coords are out-of-bounds for an m*n matrix C
inline __device__
bool is_block_oob(int m, int n)
{
/// thread blocks within the cohort may be fully out-of-bounds
return (block_item_coords.x >= n) || (block_item_coords.y >= m);
}
//-------------------------------------------------------------------------
// Grid launch API
//-------------------------------------------------------------------------
/// Compute the kernel grid extents (in thread blocks) for consuming an m*n matrix C
inline __host__ __device__
static dim3 grid_dims(int m, int n)
{
// Extents of C matrix in cohorts
int2 grid_cohort_dims = make_int2(
(n + CohortItemsX - 1) / CohortItemsX,
(m + CohortItemsY - 1) / CohortItemsY);
return dim3(
grid_cohort_dims.x * BlocksPerCohort, // gridDim.x is width of grid in cohorts * size of cohort in blocks
grid_cohort_dims.y, // gridDim.y is height of grid in cohorts
1); // gridDim.z is reserved for optional k-splitting
}
};
/******************************************************************************
* grid_raster (Default specializations)
******************************************************************************/
/**
* \brief Abstraction for enumerating \p block_task within an input matrix
* (Default N* specialization)
*
* Maps thread blocksin column-major fashion
*/
template <
int BlockItemsY, ///< Height in rows of a block-wide tile in matrix C
int BlockItemsX, ///< Width in columns of a block-wide tile in matrix C
matrix_transform_t::kind_t TransformB> ///< View transform enumerant for matrix B
struct grid_raster<
BlockItemsY,
BlockItemsX,
matrix_transform_t::NonTranspose, ///< View transform enumerant for matrix A
TransformB,
grid_raster_strategy::Default> ///< Strategy for enumerating \p block_task within an input matrix
:
grid_raster<
BlockItemsY,
BlockItemsX,
matrix_transform_t::NonTranspose,
TransformB,
grid_raster_strategy::ColumnMajor>
{};
/**
* \brief Abstraction for enumerating \p block_task within an input matrix
* (Default TT specialization)
*
* Maps thread blocksin row-major fashion
*/
template <
int BlockItemsY, ///< Height in rows of a block-wide tile in matrix C
int BlockItemsX> ///< Width in columns of a block-wide tile in matrix C
struct grid_raster<
BlockItemsY,
BlockItemsX,
matrix_transform_t::Transpose, ///< View transform enumerant for matrix A
matrix_transform_t::Transpose, ///< View transform enumerant for matrix B
grid_raster_strategy::Default> ///< Strategy for enumerating \p block_task within an input matrix
:
grid_raster<
BlockItemsY,
BlockItemsX,
matrix_transform_t::Transpose,
matrix_transform_t::Transpose,
grid_raster_strategy::RowMajor>
{};
/**
* \brief Abstraction for enumerating \p block_task within an input matrix
* (Default TN specialization)
*
* Maps thread blocksin blocked cohorts
*/
template <
int BlockItemsY, ///< Height in rows of a block-wide tile in matrix C
int BlockItemsX> ///< Width in columns of a block-wide tile in matrix C
struct grid_raster<
BlockItemsY,
BlockItemsX,
matrix_transform_t::Transpose, ///< View transform enumerant for matrix A
matrix_transform_t::NonTranspose, ///< View transform enumerant for matrix B
grid_raster_strategy::Default> ///< Strategy for enumerating \p block_task within an input matrix
:
grid_raster<
BlockItemsY,
BlockItemsX,
matrix_transform_t::Transpose,
matrix_transform_t::NonTranspose,
grid_raster_strategy::TiledCohort>
{};
} // namespace gemm
} // namespace cutlass

View File

@@ -0,0 +1,302 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* Abstraction for coordinating inter-block k-splitting
*/
#include <stdint.h>
#include "../util/util.h"
namespace cutlass {
namespace gemm {
/******************************************************************************
* Storage and initialization
******************************************************************************/
enum
{
NumFlagsSplitK = 4096
};
/**
* Global K-split semaphore flags
*
* TODO: use demand-allocated storage to provide copies for concurrent streams
*/
__device__ int d_flags_split_k[NumFlagsSplitK];
/**
* Preparation kernel for zero-initializing semaphore flags
*/
__global__ void prepare_kernel(int *d_flags_split_k)
{
int tid = (blockIdx.x * blockDim.x) + threadIdx.x;
if (tid < NumFlagsSplitK)
d_flags_split_k[tid] = 0;
}
/******************************************************************************
* k_split_control
******************************************************************************/
/**
* \brief Abstraction for coordinating inter-block k-splitting
*/
struct k_split_control
{
/// Extent of a thread block's partition along the GEMM K-axis
int split_k;
/// Whether or not to use a semaphore for inter-block k-splitting.
bool use_semaphore;
/// Pointer to semaphore
int *d_flags;
//-------------------------------------------------------------------------
// Device API
//-------------------------------------------------------------------------
/**
* Return the thread block's starting coordinate (k) within the
* multiplicand matrices
*/
inline __device__
int block_begin_item_k()
{
return blockIdx.z * split_k;
}
/**
* Return the thread block's ending coordinate (k) within the multiplicand
* matrices (one-past)
*/
inline __device__
int block_end_item_k(int dim_k)
{
int next_start_k = block_begin_item_k() + split_k;
return __NV_STD_MIN(next_start_k, dim_k);
}
/**
* Whether the thread block is a secondary accumulator in an inter-block
* k-splitting scheme
*/
inline __device__
bool is_secondary_accumulator()
{
return (blockIdx.z > 0);
}
/**
* Wait for predecessor thread block(s) to produce the exclusive
* partial-sums for this block-wide tile
*/
inline __device__
void wait()
{
// Wait on semaphore
if ((use_semaphore) && (blockIdx.z > 0))
{
if (threadIdx.x == 0)
{
int bid = (blockIdx.y * gridDim.x) + blockIdx.x;
int hash = bid % NumFlagsSplitK;
int found;
int looking = blockIdx.z;
while (true)
{
asm volatile ("ld.global.cg.u32 %0, [%1];\n" : "=r"(found) : "l"(d_flags + hash));
if (found == looking)
break;
/// Fence to keep load from being hoisted from the loop
__syncwarp(0x00000001);
}
}
__syncthreads();
}
}
/**
* Signal the successor thread_block(s) that the inclusive partial-sums
* from this block-wide tile are available
*/
inline __device__
void signal()
{
if (use_semaphore)
{
__syncthreads();
if (threadIdx.x == 0)
{
int bid = (blockIdx.y * gridDim.x) + blockIdx.x;
int hash = bid % NumFlagsSplitK;
int val = blockIdx.z + 1;
asm volatile ("st.global.cg.u32 [%0], %1;\n" : : "l"(d_flags + hash), "r"(val));
}
}
}
//-------------------------------------------------------------------------
// Grid launch API
//-------------------------------------------------------------------------
/**
* Constructor
*/
inline
k_split_control(
int *d_flags,
int sm_count,
int max_sm_occupancy,
int dim_k,
int block_tile_items_k,
dim3 block_dims,
dim3 &grid_dims) ///< [in,out]
:
d_flags(d_flags),
split_k(dim_k)
{
// Compute wave efficiency
float wave_efficiency = get_wave_efficiency(
sm_count,
max_sm_occupancy,
block_dims,
grid_dims);
// Update split-k if wave efficiency is less than some threshold
if (wave_efficiency < 0.9)
{
int num_threadblocks = grid_dims.x * grid_dims.y * grid_dims.z;
// Ideal number of thread blocks in grid
int ideal_threadblocks = lcm(sm_count, num_threadblocks);
// Desired number of partitions to split K-axis into
int num_partitions = ideal_threadblocks / num_threadblocks;
// Compute new k-split share
int new_split_k = (dim_k + num_partitions - 1) / num_partitions;
// Round split_k share to the nearest block_task_policy_t::BlockItemsK
new_split_k = round_nearest(new_split_k, block_tile_items_k);
// Recompute k-splitting factor with new_split_k
num_partitions = (dim_k + new_split_k - 1) / new_split_k;
// Update grid dims and k if we meet the minimum number of iterations worth the overhead of splitting
int min_iterations_k = 8;
if (((new_split_k / block_tile_items_k) > min_iterations_k) && // We're going to go through at least this many k iterations
(sm_count * max_sm_occupancy < NumFlagsSplitK)) // We have enough semaphore flags allocated
{
grid_dims.z = num_partitions;
split_k = new_split_k;
}
}
use_semaphore = (grid_dims.z > 1);
}
/**
* Initializer
*/
cudaError_t prepare(
cudaStream_t stream, ///< CUDA stream to launch kernels within. Default is stream<sub>0</sub>.
bool debug_synchronous) ///< Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console if DEBUG is defined. Default is \p false.
{
cudaError error = cudaSuccess;
if (use_semaphore)
{
int block_threads = 128;
int grid_dims = (NumFlagsSplitK + block_threads - 1) / block_threads;
prepare_kernel<<<grid_dims, block_threads, 0, stream>>>(d_flags);
// Check for failure to launch
if (CUDA_PERROR_DEBUG(error = cudaPeekAtLastError()))
return error;
// Sync the stream if specified to flush runtime errors
if (debug_synchronous && (CUDA_PERROR_DEBUG(error = cudaStreamSynchronize(stream))))
return error;
}
return error;
}
/**
* Compute the efficiency of dispatch wave quantization
*/
float get_wave_efficiency(
int sm_count,
int max_sm_occupancy,
dim3 block_dims,
dim3 grid_dims)
{
// Heuristic for how many warps are needed to saturate an SM for a given
// multiply-accumulate genre. (NB: We could make this more rigorous by
// specializing on data types and SM width)
int saturating_warps_per_sm = 16;
int num_threadblocks = grid_dims.x * grid_dims.y * grid_dims.z;
int threads_per_threadblock = block_dims.x * block_dims.y;
int warps_per_threadblock = threads_per_threadblock / 32;
int saturating_threadblocks_per_sm = (saturating_warps_per_sm + warps_per_threadblock - 1) / warps_per_threadblock;
int saturating_residency = sm_count * saturating_threadblocks_per_sm;
int full_waves = num_threadblocks / saturating_residency;
int remainder_threadblocks = num_threadblocks % saturating_residency;
int total_waves = (remainder_threadblocks == 0) ? full_waves : full_waves + 1;
float last_wave_saturating_efficiency = float(remainder_threadblocks) / saturating_residency;
return (float(full_waves) + last_wave_saturating_efficiency) / total_waves;
}
};
} // namespace gemm
} // namespace cutlass

View File

@@ -0,0 +1,461 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* Thread-level multiply-accumulate abstraction
*/
#include "../util/util.h"
#include "dp_accummulate.h"
namespace cutlass {
namespace gemm {
/******************************************************************************
* thread_accumulator (generic specialization)
******************************************************************************/
/**
* \brief Thread-level multiply-accumulate abstraction (generic specialization)
*
* The thread_accumulator class maintains a MxN tile of accumulators in
* registers to which MxNxK matrix products of two thread tiles A (MxK)
* and B (KxN) can be added, where:
* M = ThreadItemsY
* N = ThreadItemsX
* K = sizeof(dp_vector_t) / sizeof(value_t).
*
* In order to leverage architecture-specific "dot-product accumulate" ISA
* operations, K is dictated by the thread_accumulator class in the form of
* the member-type dp_vector_t, which defines a K-component vector of value_t.
* The multiplicand inputs A and B are provided as arrays of dp_vector_t having
* extents ThreadItemsY and ThreadItemsX, respectively. (In the single
* component "dp1" scenario where dp_vector_t == value_t and thus K == 1, the
* multiplication is simply the outer product of two vectors.)
*
* The accumulators are zero-initialized in a two-phase process (construction +
* initialization) that requires shared storage in the form of the member-type
* scratch_storage_t during construction. (A single scratch_storage_t instance
* can be uniformly referenced across all threads in the block during
* construction *if* the block is synchronized between construction and
* initialization.)
*
* NB: This generic class is not directly constructible. Architecture- and
* algorithm-specific template specializations will provide the API
* functionality prescribed here.
*/
template <
int ThreadItemsY, ///< Height of thread tile in accum_t
int ThreadItemsX, ///< Width of thread tile in accum_t
typename value_t, ///< Multiplicand value type
typename accum_t, ///< Accumulator value type
int ACCUM_BYTES = ///< Size in bytes of accum_t
sizeof(accum_t),
arch_family_t::kind_t ArchFamily = ///< Architectural family enumerant
CUTLASS_ARCH_FAMILY>
struct thread_accumulator
{
protected:
//-------------------------------------------------------------------------
// Constants and types
//-------------------------------------------------------------------------
/// Specialized dot-product traits type
typedef dp_accummulate<value_t, accum_t> dp_accum_traits_t;
public:
//-------------------------------------------------------------------------
// Member types
//-------------------------------------------------------------------------
/// Dot-product vector type
typedef typename dp_accum_traits_t::dp_vector_t dp_vector_t;
/// Scratch storage layout
struct scratch_storage_t {};
protected:
//-------------------------------------------------------------------------
// Data members
//-------------------------------------------------------------------------
/// Thread's tile of accumulators
accum_t accumulators[ThreadItemsY][ThreadItemsX];
//-------------------------------------------------------------------------
// Utility methods
//-------------------------------------------------------------------------
/**
* Compute a multiply-add at accumulator coordinates (x, y)
*/
inline __device__
void mad_xy(
dp_vector_t (&tile_a)[ThreadItemsY],
dp_vector_t (&tile_b)[ThreadItemsX],
int x,
int y)
{
dp_accum_traits_t::mad(
accumulators[y][x],
tile_a[y],
tile_b[x],
accumulators[y][x]);
}
public:
//-------------------------------------------------------------------------
// Constructor API
//-------------------------------------------------------------------------
/// Constructor
inline __device__
thread_accumulator(
scratch_storage_t &scratch)
{}
//-------------------------------------------------------------------------
// Accumulator API
//-------------------------------------------------------------------------
/**
* \brief Zero-initialize thread accumulators.
*
* If a common reference to a single block-wide shared instance of scratch_storage_t
* is used during construction, the block must be synchronized after construction
* but prior to the invocation of init().
*/
inline __device__
void init()
{
#pragma unroll
for (int y = 0; y < ThreadItemsY; ++y) {
#pragma unroll
for (int x = 0; x < ThreadItemsX; ++x)
{
accumulators[y][x] = accum_t(0);
}
}
}
/**
* Retrieve the accumulator at thread tile coordinates (x, y)
*/
inline __device__
accum_t get(int x, int y)
{
// Accumulators are row-major
return accumulators[y][x];
}
/**
* \brief Compute the product of tile_a and tile_b and add the result to
* the tile of accumulators.
*/
inline __device__
void multiply_accumulate(
dp_vector_t (&tile_a)[ThreadItemsY],
dp_vector_t (&tile_b)[ThreadItemsX])
{
// Simply traverse the accumulator tile in row-major order
#pragma unroll
for (int y = 0; y < ThreadItemsY; ++y)
{
#pragma unroll
for (int x = 0; x < ThreadItemsX; ++x)
{
mad_xy(tile_a, tile_b, x, y);
}
}
}
};
/******************************************************************************
* thread_accumulator (__half->__half specialization)
******************************************************************************/
/**
* \brief Thread-level multiply-accumulate abstraction (__half->__half specialization)
*
* NB: Because we use the 2-item SIMD instruction HFMA2:
* - ThreadItemsX must be an even multiple of 2
* - ThreadItemsY must be an even multiple of 2
*
*/
template <
int ThreadItemsY, ///< Height in rows of thread tile in C
int ThreadItemsX, ///< Width in columns of thread tile in C
arch_family_t::kind_t ArchFamily> ///< Architectural family enumerant
struct thread_accumulator<
ThreadItemsY,
ThreadItemsX,
__half, ///< Multiplicand value type (matrices A and B)
__half, ///< Accumulator value type (matrix C and scalars)
2, ///< Size in bytes of accum_t
ArchFamily>
{
protected:
//-------------------------------------------------------------------------
// Constants and types
//-------------------------------------------------------------------------
/// Constants
enum
{
/// Height of thread tile in column-major uint32_t SIMD pairs along Y dimension
ThreadTilePairsY = divide_assert<ThreadItemsY, 2>::value,
/// Width of thread tile in column-major uint32_t SIMD pairs along X dimension
ThreadTilePairsX = ThreadItemsX,
/// Number of SIMD pairs in thread's slice of block-wide tile multiplicand A
ThreadPairsA = divide_assert<ThreadItemsY, 2>::value,
/// Number of SIMD pairs in thread's slice of block-wide tile multiplicand B
ThreadPairsB = divide_assert<ThreadItemsX, 2>::value,
};
public:
//-------------------------------------------------------------------------
// Member types
//-------------------------------------------------------------------------
/// Dot-product vector type
typedef __half dp_vector_t;
/// Scratch storage layout
struct scratch_storage_t {};
private:
//-------------------------------------------------------------------------
// Members
//-------------------------------------------------------------------------
/// Thread's tile of C accumulator pairs (the uint32_t SIMD pairs are
/// column-major, the 2D tile layout is also column-major)
uint32_t accumulator_pairs[ThreadTilePairsX][ThreadTilePairsY];
//-------------------------------------------------------------------------
// Utility methods
//-------------------------------------------------------------------------
/**
* Compute an HFMA2 MAD
*/
inline __device__ void mad(
uint32_t &d,
const uint32_t &a,
const uint32_t &b,
const uint32_t &c)
{
asm volatile ("fma.rn.f16x2 %0, %1, %2, %3;\n"
: "=r"(d) : "r"(a), "r"(b), "r"(c));
}
/**
* Compute an HFMA2 MAD with replicated b.lo:
* d{hi} = a{hi} * b{lo} + c{hi};
* d{lo} = a{lo} * b{lo} + c{lo};
*/
inline __device__ void mad_replicate_low(
uint32_t &d,
const uint32_t &a,
const uint32_t &b,
const uint32_t &c)
{
// Replicate low halves of b
uint32_t replicate;
asm volatile (
"{"
" .reg .b16 b_low,b_high;\n"
" mov.b32 {b_low,b_high}, %1;\n"
" mov.b32 %0, {b_low,b_low};\n"
"}" : "=r"(replicate) : "r"(b));
mad(d, a, replicate, c);
}
/**
* Compute an HFMA2 MAD with replicated b.hi:
* d{hi} = a{hi} * b{hi} + c{hi};
* d{lo} = a{lo} * b{hi} + c{lo};
*/
inline __device__ void mad_replicate_high(
uint32_t &d,
const uint32_t &a,
const uint32_t &b,
const uint32_t &c)
{
// Replicate high halves of b
uint32_t replicate;
asm volatile (
"{"
" .reg .b16 b_low,b_high;\n"
" mov.b32 {b_low,b_high}, %1;\n"
" mov.b32 %0, {b_high,b_high};\n"
"}" : "=r"(replicate) : "r"(b));
mad(d, a, replicate, c);
}
/**
* Compute a multiply-add at accumulator SIMD-pair coordinates (pair_x, pair_y)
*/
inline __device__
void mad_xy_even(
uint32_t (&pairs_tile_a)[ThreadPairsA],
uint32_t (&pairs_tile_b)[ThreadPairsB],
int pair_x,
int pair_y)
{
// Even column: use low half of the b pair
mad_replicate_low(
accumulator_pairs[pair_x][pair_y],
pairs_tile_a[pair_y],
pairs_tile_b[pair_x / 2],
accumulator_pairs[pair_x][pair_y]);
}
/**
* Compute a multiply-add at accumulator SIMD-pair coordinates (pair_x, pair_y)
*/
inline __device__
void mad_xy_odd(
uint32_t (&pairs_tile_a)[ThreadPairsA],
uint32_t (&pairs_tile_b)[ThreadPairsB],
int pair_x,
int pair_y)
{
// Odd column: use high half of the b pair
mad_replicate_high(
accumulator_pairs[pair_x][pair_y],
pairs_tile_a[pair_y],
pairs_tile_b[pair_x / 2],
accumulator_pairs[pair_x][pair_y]);
}
public:
//-------------------------------------------------------------------------
// Constructor API
//-------------------------------------------------------------------------
/// Constructor
inline __device__
thread_accumulator(
scratch_storage_t &scratch)
{}
//-------------------------------------------------------------------------
// Accumulator API
//-------------------------------------------------------------------------
/**
* Zero-initialize thread accumulators.
*/
inline __device__
void init()
{
#pragma unroll
for (int y = 0; y < ThreadTilePairsY; ++y)
{
#pragma unroll
for (int x = 0; x < ThreadTilePairsX; ++x)
{
accumulator_pairs[x][y] = 0;
}
}
}
/**
* Retrieve the accumulator at thread tile coordinates (x, y)
*/
inline __device__
__half get(int x, int y)
{
// SIMD pairs are column-major
uint32_t pair = accumulator_pairs[x][y / 2];
return reinterpret_cast<__half (&)[2]>(pair)[y % 2];
}
/**
* \brief Compute the product of pairs_tile_a and pairs_tile_b and add the result to
* the tile of accumulators.
*/
inline __device__
void multiply_accumulate(
dp_vector_t (&tile_a)[ThreadItemsY],
dp_vector_t (&tile_b)[ThreadItemsX])
{
typedef uint32_t pairs_tile_a_t[ThreadPairsA];
typedef uint32_t pairs_tile_b_t[ThreadPairsB];
// Alias slices in pairs
pairs_tile_a_t &pairs_tile_a = reinterpret_cast<pairs_tile_a_t&>(tile_a);
pairs_tile_b_t &pairs_tile_b = reinterpret_cast<pairs_tile_b_t&>(tile_b);
// Simply traverse the accumulator tile in column-major order
#pragma unroll
for (int x = 0; x < ThreadTilePairsX; ++x)
{
#pragma unroll
for (int y = 0; y < ThreadTilePairsY; ++y)
{
mad_xy_even(pairs_tile_a, pairs_tile_b, x, y);
}
}
}
};
} // namespace gemm
} // namespace cutlass

View File

@@ -0,0 +1,207 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* Thread-level multiply-accumulate abstraction
* (Volta 4B accum_t specialization)
*/
#include <mma.h>
#include "../util/util.h"
#include "dp_accummulate.h"
namespace cutlass {
namespace gemm {
/*!
*\brief matrix_layout to perform conversion between Cutlass types and WMMA types
*/
template <matrix_transform_t::kind_t>
struct matrix_layout;
/// Maps matrix_transform_t::NonTranspose to nvcuda::wmma::mem_col_major
template <>
struct matrix_layout<matrix_transform_t::NonTranspose>
{
/// Type tag in nvcuda::wmma namespace
typedef nvcuda::wmma::col_major tag;
/// Column major layout
static const nvcuda::wmma::layout_t kind = nvcuda::wmma::mem_col_major;
/// Cutlass matrix transform kind
static const matrix_transform_t::kind_t cutlass_kind = matrix_transform_t::NonTranspose;
};
/// Maps matrix_transform_t::NonTranspose to nvcuda::wmma::mem_row_major
template <>
struct matrix_layout<matrix_transform_t::Transpose>
{
/// Type tag in nvcuda::wmma namespace
typedef nvcuda::wmma::row_major tag;
/// Column major layout
static const nvcuda::wmma::layout_t kind = nvcuda::wmma::mem_row_major;
/// Cutlass matrix transform kind
static const matrix_transform_t::kind_t cutlass_kind = matrix_transform_t::Transpose;
};
/*!
* \brief Warp-synchronous matrix multiply-accumulate abstraction
*
* wmma_accumulator maps the CUDA WMMA API onto the GEMM structure
*/
template <
int WarpItemsY, /// Number of rows of the warp's accumulator tile
int WarpItemsX, /// Number of columns of the warp's accumulator tile
int WmmaItemsY, /// Number of rows in a single WMMA operation
int WmmaItemsX, /// Number of columns in a single WMMA operation
int WmmaItemsK, /// Inner dimension of WMMA operation
typename value_a_t, /// Type of A operand
typename value_b_t, /// Type of B operand
typename accum_t, /// Type of source and destination accumulators
matrix_transform_t::kind_t TransformA, /// Layout of A operand
matrix_transform_t::kind_t TransformB /// Layout of B operand
>
struct wmma_accumulator
{
public:
//-------------------------------------------------------------------------
// Constants and types
//-------------------------------------------------------------------------
enum
{
/// Number of WMMA blocks in warp row
WmmaBlocksX = divide_assert<WarpItemsX, WmmaItemsX>::value,
/// Number of WMMA blocks in a warp column
WmmaBlocksY = divide_assert<WarpItemsY, WmmaItemsY>::value,
};
/// Fragment type for matrix operand A
typedef nvcuda::wmma::fragment<
nvcuda::wmma::matrix_a,
WmmaItemsY,
WmmaItemsX,
WmmaItemsK,
value_a_t,
typename matrix_layout<TransformA>::tag>
fragment_a_t;
/// Fragment type for matrix operand B
typedef nvcuda::wmma::fragment<
nvcuda::wmma::matrix_b,
WmmaItemsY,
WmmaItemsX,
WmmaItemsK,
value_b_t,
typename matrix_layout<TransformB>::tag>
fragment_b_t;
/// Fragment type for accumulator
typedef nvcuda::wmma::fragment<
nvcuda::wmma::accumulator,
WmmaItemsY,
WmmaItemsX,
WmmaItemsK,
accum_t>
accumulator_t;
/// Scratch storage layout
struct scratch_storage_t
{
/// Initialization vector
uint4 zero_slab;
};
public:
//-------------------------------------------------------------------------
// Data members
//-------------------------------------------------------------------------
/// Thread's tile of accumulators
accumulator_t accumulators[WmmaBlocksX][WmmaBlocksY];
public:
//-------------------------------------------------------------------------
// Constructor API
//-------------------------------------------------------------------------
/// Constructor initializes accumulators to zero
inline __device__
wmma_accumulator()
{
init();
}
//-------------------------------------------------------------------------
// Accumulator API
//-------------------------------------------------------------------------
/**
* \brief Zero-initialize thread accumulators.
*/
inline __device__
void init()
{
#pragma unroll
for (int x = 0; x < WmmaBlocksX; ++x)
{
#pragma unroll
for (int y = 0; y < WmmaBlocksY; ++y)
{
nvcuda::wmma::fill_fragment(accumulators[x][y], accum_t(0));
}
}
}
/**
* \brief Compute the product of tile_a and tile_b and add the result to
* the tile of accumulators.
*/
inline __device__
void multiply_accumulate(
fragment_a_t (&tile_a)[WmmaBlocksY],
fragment_b_t (&tile_b)[WmmaBlocksX])
{
#pragma unroll
for (int x = 0; x < WmmaBlocksX; ++x)
{
#pragma unroll
for (int y = 0; y < WmmaBlocksY; ++y)
{
nvcuda::wmma::mma_sync(accumulators[x][y], tile_a[y], tile_b[x], accumulators[x][y]);
}
}
}
};
} // namespace gemm
} // namespace cutlass

112
cutlass/util/debug.h Normal file
View File

@@ -0,0 +1,112 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* \brief Debugging and logging functionality
*/
#include <stdio.h>
namespace cutlass {
/******************************************************************************
* Debug and logging macros
******************************************************************************/
/**
* Formats and prints the given message to stdout
*/
#if !defined(CUDA_LOG)
#if !defined(__CUDA_ARCH__)
#define CUDA_LOG(format, ...) printf(format,__VA_ARGS__)
#else
#define CUDA_LOG(format, ...) printf("[block (%d,%d,%d), thread (%d,%d,%d)]: " format, blockIdx.x, blockIdx.y, blockIdx.z, threadIdx.x, threadIdx.y, threadIdx.z, __VA_ARGS__);
#endif
#endif
/**
* Formats and prints the given message to stdout only if DEBUG is defined
*/
#if !defined(CUDA_LOG_DEBUG)
#ifdef DEBUG
#define CUDA_LOG_DEBUG(format, ...) CUDA_LOG(format, __VA_ARGS__)
#else
#define CUDA_LOG_DEBUG(format, ...)
#endif
#endif
/**
* \brief The corresponding error message is printed to \p stderr (or \p stdout in device code) along with the supplied source context.
*
* \return The CUDA error.
*/
__host__ __device__ inline cudaError_t cuda_perror_impl(
cudaError_t error,
const char* filename,
int line)
{
(void)filename;
(void)line;
if (error)
{
#if !defined(__CUDA_ARCH__)
fprintf(stderr, "CUDA error %d [%s, %d]: %s\n", error, filename, line, cudaGetErrorString(error));
fflush(stderr);
#else
printf("CUDA error %d [%s, %d]\n", error, filename, line);
#endif
}
return error;
}
/**
* \brief Perror macro
*/
#ifndef CUDA_PERROR
#define CUDA_PERROR(e) cuda_perror_impl((cudaError_t) (e), __FILE__, __LINE__)
#endif
/**
* \brief Perror macro with exit
*/
#ifndef CUDA_PERROR_EXIT
#define CUDA_PERROR_EXIT(e) if (cuda_perror_impl((cudaError_t) (e), __FILE__, __LINE__)) { exit(1); }
#endif
/**
* \brief Perror macro only if DEBUG is defined
*/
#ifndef CUDA_PERROR_DEBUG
#ifdef DEBUG
#define CUDA_PERROR_DEBUG(e) CUDA_PERROR(e)
#else
#define CUDA_PERROR_DEBUG(e) (e)
#endif
#endif
} // namespace cutlass

View File

@@ -0,0 +1,216 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* \brief Utilities for device introspection
*/
#include "debug.h"
#include "nv_std.h"
#include "printable.h"
namespace cutlass {
/******************************************************************************
* math_operation_class_t
*
* Enumeration to select the appropriate math operation
*
* The assumption is multiple math operations may be used to compute GEMM
* for a given selection of operand and accumulator types.
*
******************************************************************************/
/// Math operation
enum class math_operation_class_t
{
scalar, // scalar (and vector) multiply-accumulate operations
matrix // Volta tensor operations
};
/******************************************************************************
* arch_family_t
******************************************************************************/
/**
* \brief Enumeration of NVIDIA GPU architectural families
*/
struct arch_family_t
{
/// \brief Enumerants
enum kind_t
{
Unsupported = 0,
Kepler = 3,
Maxwell = 5,
Volta = 7,
};
/// Enumerant value
kind_t kind;
/// Default constructor
arch_family_t() : kind(Unsupported) {}
/// Copy constructor
arch_family_t(const kind_t &other_kind) : kind(other_kind) {}
/// Cast to kind_t
operator kind_t() const { return kind; }
/// Returns the instance as a string
__host__ __device__ inline
char const* to_string() const
{
switch (kind)
{
case Kepler: return "Kepler";
case Maxwell: return "Maxwell";
case Volta: return "Volta";
case Unsupported:
default: return "Unsupported";
}
}
/// Insert the formatted instance into the output stream
void print(std::ostream& out) const { out << to_string(); }
};
/**
* Macro for architecture targeted by the current compiler pass
*/
#if defined(__CUDA_ARCH__)
#define CUTLASS_ARCH __CUDA_ARCH__
#else
#define CUTLASS_ARCH 0
#endif
/**
* Macro for architecture family targeted by the current compiler pass
*/
#define CUTLASS_ARCH_FAMILY \
( \
(CUTLASS_ARCH < 300) ? \
arch_family_t::Unsupported : \
(CUTLASS_ARCH < 500) ? \
arch_family_t::Kepler : \
(CUTLASS_ARCH < 700) ? \
arch_family_t::Maxwell : \
arch_family_t::Volta \
)
/******************************************************************************
* Device introspection
******************************************************************************/
/**
* Empty kernel for querying PTX manifest metadata (e.g., version) for the current device
*/
template <typename T>
__global__ void empty_kernel(void) { }
/**
* \brief Retrieves the PTX version that will be used on the current device (major * 100 + minor * 10)
*/
cudaError_t ptx_version(int &version)
{
struct Dummy
{
/// Type definition of the empty_kernel kernel entry point
typedef void (*EmptyKernelPtr)();
/// Force empty_kernel<void> to be generated if this class is used
EmptyKernelPtr Empty()
{
return empty_kernel<void>;
}
};
cudaError_t error = cudaSuccess;
do
{
cudaFuncAttributes empty_kernel_attrs;
if (CUDA_PERROR_DEBUG(error = cudaFuncGetAttributes(&empty_kernel_attrs, empty_kernel<void>))) break;
version = empty_kernel_attrs.ptxVersion * 10;
}
while (0);
return error;
}
/**
* \brief Retrieves the SM version (major * 100 + minor * 10) for the current device
*/
cudaError_t get_sm_version(int &sm_version)
{
cudaError_t error = cudaSuccess;
// Get device ordinal
int device_ordinal;
if (CUDA_PERROR_DEBUG(error = cudaGetDevice(&device_ordinal)))
return error;
// Fill in SM version
int major, minor;
if (CUDA_PERROR_DEBUG(error = cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device_ordinal)))
return error;
if (CUDA_PERROR_DEBUG(error = cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device_ordinal)))
return error;
sm_version = major * 100 + minor * 10;
return error;
}
/**
* \brief Retrieves the count for the current device
*/
cudaError_t get_sm_count(int &sm_count)
{
cudaError_t error = cudaSuccess;
// Get device ordinal
int device_ordinal;
if (CUDA_PERROR_DEBUG(error = cudaGetDevice(&device_ordinal)))
return error;
// Get SM count
if (CUDA_PERROR_DEBUG(error = cudaDeviceGetAttribute (&sm_count, cudaDevAttrMultiProcessorCount, device_ordinal)))
return error;
return error;
}
} // namespace cutlass

View File

@@ -0,0 +1,484 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* \brief I/O device intrinsics
*/
#include <stdint.h>
#include <cuda_fp16.h>
#include "nv_std.h"
#include "math.h"
namespace cutlass {
/******************************************************************************
* io_vector
******************************************************************************/
/**
* Base aligned storage for IO vector
*/
template <typename value_t, int VectorItems, int AlignBytes> struct io_vector_base;
template <typename value_t, int VectorItems> struct __align__(1) io_vector_base<value_t, VectorItems, 1> { value_t buff[VectorItems]; };
template <typename value_t, int VectorItems> struct __align__(2) io_vector_base<value_t, VectorItems, 2> { value_t buff[VectorItems]; };
template <typename value_t, int VectorItems> struct __align__(4) io_vector_base<value_t, VectorItems, 4> { value_t buff[VectorItems]; };
template <typename value_t, int VectorItems> struct __align__(8) io_vector_base<value_t, VectorItems, 8> { value_t buff[VectorItems]; };
template <typename value_t, int VectorItems> struct __align__(16) io_vector_base<value_t, VectorItems, 16> { value_t buff[VectorItems]; };
/**
* \brief Aligned vector type for coarsening data movement instructions
*
* Exposes the member constant \p VectorItems, the actual number of component
* values comprising the io_vector
*/
template <
typename value_t, ///< Component value type
int MaxVectorItems, ///< Maximum allowable component values
int MaxAlignBytes ///< Maximum allowable alignment
= __NV_STD_MIN(16, MaxVectorItems * sizeof(value_t)),
int AlignBytes ///< Actual alignment
= __NV_STD_MIN(sizeof(value_t) * MaxVectorItems, MaxAlignBytes),
int VectorItems ///< Actual number of component values
= divide_assert<AlignBytes, sizeof(value_t)>::value,
bool MustAlias ///< Whether we need to alias during loads/stores
= (VectorItems > 4)>
struct io_vector;
/**
* IO vector (specialization for VectorItems <= 4)
*/
template <
typename value_t,
int MaxVectorItems,
int MaxAlignBytes,
int _AlignBytes,
int _VectorItems>
struct io_vector <
value_t,
MaxVectorItems,
MaxAlignBytes,
_AlignBytes,
_VectorItems,
false>
:
io_vector_base<value_t, _VectorItems, _AlignBytes>
{
enum
{
VectorItems = _VectorItems,
AlignBytes = _AlignBytes
};
static_assert(is_pow2<AlignBytes>::value, "I/O vector alignment must be a power-of-two.");
static_assert((AlignBytes <= 16), "I/O vector alignment must <= 16B.");
inline __device__
void load(const io_vector *ptr)
{
*this = *ptr;
}
inline __device__
void load(const value_t *ptr)
{
*this = *reinterpret_cast<const io_vector*>(ptr);
}
inline __device__
void store(io_vector *ptr) const
{
*ptr = *this;
}
inline __device__
void store(value_t *ptr) const
{
*reinterpret_cast<io_vector*>(ptr) = *this;
}
};
/**
* IO vector (specialization for VectorItems > 4)
*
* NB: Workaround for NVCC not generating 128-bit loads/stores for aligned
* structures having component types < 32b
*/
template <
typename value_t,
int MaxVectorItems,
int MaxAlignBytes,
int _AlignBytes,
int _VectorItems>
struct io_vector <
value_t,
MaxVectorItems,
MaxAlignBytes,
_AlignBytes,
_VectorItems,
true>
:
io_vector_base<value_t, _VectorItems, _AlignBytes>
{
enum
{
VectorItems = _VectorItems,
AlignBytes = _AlignBytes
};
static_assert(is_pow2<AlignBytes>::value, "I/O vector alignment must be a power-of-two.");
static_assert((AlignBytes <= 16), "I/O vector alignment must <= 16B.");
typedef typename nv_std::conditional<(AlignBytes == 8),
uint2, // Use 8B load
uint4> // Use 16B load
::type align_t;
inline __device__
void load(const io_vector *ptr)
{
*reinterpret_cast<align_t*>(this) = *reinterpret_cast<const align_t*>(ptr);
}
inline __device__
void load(const value_t *ptr)
{
*reinterpret_cast<align_t*>(this) = *reinterpret_cast<const align_t*>(ptr);
}
inline __device__
void store(io_vector *ptr) const
{
*reinterpret_cast<align_t*>(ptr) = *reinterpret_cast<const align_t*>(this);
}
inline __device__
void store(value_t *ptr) const
{
*reinterpret_cast<align_t*>(ptr) = *reinterpret_cast<const align_t*>(this);
}
};
/******************************************************************************
* Macro expansions for vector loads
******************************************************************************/
/**
* Define vector-4 LD specialization for the given load modifier
*/
#define CUTLASS_LD_V4(f_name, value_t, load_modifier, ptx_type, val_constraint, ptr_constraint) \
template <typename ptr_t> \
inline __device__ \
void f_name( \
value_t (&dest)[4], \
ptr_t ptr) \
{ \
asm volatile ("ld."#load_modifier".v4."#ptx_type" {%0, %1, %2, %3}, [%4];\n" \
: \
"="#val_constraint(dest[0]), \
"="#val_constraint(dest[1]), \
"="#val_constraint(dest[2]), \
"="#val_constraint(dest[3]) \
: \
#ptr_constraint(ptr)); \
}
/**
* Define vector-2 LD specialization for the given load modifier
*/
#define CUTLASS_LD_V2(f_name, value_t, load_modifier, ptx_type, val_constraint, ptr_constraint) \
template <typename ptr_t> \
inline __device__ \
void f_name( \
value_t (&dest)[2], \
ptr_t ptr) \
{ \
asm volatile ("ld."#load_modifier".v2."#ptx_type" {%0, %1}, [%2];\n" \
: \
"="#val_constraint(dest[0]), \
"="#val_constraint(dest[1]) \
: \
#ptr_constraint(ptr)); \
}
/**
* Define vector-1 LD specialization for the given load modifier
*/
#define CUTLASS_LD_V1(f_name, value_t, load_modifier, ptx_type, val_constraint, ptr_constraint) \
template <typename ptr_t> \
inline __device__ \
void f_name( \
value_t (&dest)[1], \
ptr_t ptr) \
{ \
asm volatile ("ld."#load_modifier"."#ptx_type" %0, [%1];\n" \
: \
"="#val_constraint(dest[0]) \
: \
#ptr_constraint(ptr)); \
}
/**
* Define powers-of-two vector LD specializations
*/
#define CUTLASS_LD_ALL(f_name, value_t, load_modifier, ptx_type, val_constraint, ptr_constraint) \
CUTLASS_LD_V4(f_name, value_t, load_modifier, ptx_type, val_constraint, ptr_constraint) \
CUTLASS_LD_V2(f_name, value_t, load_modifier, ptx_type, val_constraint, ptr_constraint) \
CUTLASS_LD_V1(f_name, value_t, load_modifier, ptx_type, val_constraint, ptr_constraint)
/******************************************************************************
* Macro expansions for vector stores
******************************************************************************/
/**
* Define vector-4 ST specialization for the given load modifier
*/
#define CUTLASS_ST_V4(f_name, value_t, store_modifier, ptx_type, val_constraint, ptr_constraint) \
template <typename ptr_t> \
inline __device__ \
void f_name( \
ptr_t ptr, \
const value_t (&src)[4]) \
{ \
asm volatile ("st."#store_modifier".v4."#ptx_type" [%0], {%1, %2, %3, %4};\n" \
: : \
#ptr_constraint(ptr), \
#val_constraint(src[0]), \
#val_constraint(src[1]), \
#val_constraint(src[2]), \
#val_constraint(src[3])); \
}
/**
* Define vector-2 ST specialization for the given load modifier
*/
#define CUTLASS_ST_V2(f_name, value_t, store_modifier, ptx_type, val_constraint, ptr_constraint) \
template <typename ptr_t> \
inline __device__ \
void f_name( \
ptr_t ptr, \
const value_t (&src)[2]) \
{ \
asm volatile ("st."#store_modifier".v2."#ptx_type" [%0], {%1, %2};\n" \
: : \
#ptr_constraint(ptr), \
#val_constraint(src[0]), \
#val_constraint(src[1])); \
}
/**
* Define vector-1 ST specialization for the given load modifier
*/
#define CUTLASS_ST_V1(f_name, value_t, store_modifier, ptx_type, val_constraint, ptr_constraint) \
template <typename ptr_t> \
inline __device__ \
void f_name( \
ptr_t ptr, \
const value_t (&src)[1]) \
{ \
asm volatile ("st."#store_modifier"."#ptx_type" [%0], %1;\n" \
: : \
#ptr_constraint(ptr), \
#val_constraint(src[0])); \
}
/**
* Define powers-of-two vector LD specializations
*/
#define CUTLASS_ST_ALL(f_name, value_t, load_modifier, ptx_type, val_constraint, ptr_constraint) \
CUTLASS_ST_V4(f_name, value_t, load_modifier, ptx_type, val_constraint, ptr_constraint) \
CUTLASS_ST_V2(f_name, value_t, load_modifier, ptx_type, val_constraint, ptr_constraint) \
CUTLASS_ST_V1(f_name, value_t, load_modifier, ptx_type, val_constraint, ptr_constraint)
/******************************************************************************
* Macro expansions for vector IO
******************************************************************************/
/**
* Define global and shared LD specializations
*/
#define CUTLASS_IO(value_t, ptx_type, val_constraint) \
CUTLASS_LD_ALL(ldg_cg_internal, value_t, global.cg, ptx_type, val_constraint, l) \
CUTLASS_ST_ALL(stg_cg_internal, value_t, global.cg, ptx_type, val_constraint, l)
// Define IO for useful types
CUTLASS_IO(double, f64, d)
CUTLASS_IO(float, f32, f)
CUTLASS_IO(int64_t, b64, l)
CUTLASS_IO(int32_t, b32, r)
CUTLASS_IO(int16_t, b16, h)
// Macro cleanup
#undef CUTLASS_IO
#undef CUTLASS_LD_ALL
#undef CUTLASS_LD_V4
#undef CUTLASS_LD_V2
#undef CUTLASS_LD_V1
#undef CUTLASS_ST_ALL
#undef CUTLASS_ST_V4
#undef CUTLASS_ST_V2
#undef CUTLASS_ST_V1
/******************************************************************************
* I/O cast types
******************************************************************************/
/// Provides the type for which to reinterpret-cast a given vector
template <
typename value_t,
int IoVecDim,
int ValueBytes = sizeof(value_t)>
struct io_cast
{
typedef value_t type[IoVecDim];
};
/// Provides the type for which to reinterpret-cast a vector of 1B types
template <
typename value_t,
int IoVecDim>
struct io_cast<value_t, IoVecDim, 1>
{
typedef typename nv_std::conditional<
(IoVecDim < 2),
int8_t[1], // Use 8b load
typename nv_std::conditional<
(IoVecDim < 4),
int16_t[1], // Use 16b load
int32_t[IoVecDim / 4]>::type>::type // Use up to 128b load
type;
};
/// Provides the type for which to reinterpret-cast a vector of 2B types
template <
typename value_t,
int IoVecDim>
struct io_cast<value_t, IoVecDim, 2>
{
typedef typename nv_std::conditional<
(IoVecDim < 2),
int16_t[1], // Use 16b load
int32_t[IoVecDim / 2]>::type // Use up to 128b load
type;
};
/******************************************************************************
* ldg_cg intrinsics
******************************************************************************/
/// Load from global (cache-global modifier)
template <typename value_t, typename ptr_t>
inline __device__
void ldg_cg(
value_t &dest,
ptr_t d_in)
{
// Cast dest to a different array type if necessary
ldg_cg_internal(
reinterpret_cast<typename io_cast<value_t, 1>::type &>(dest),
d_in);
}
/// Load from global (cache-global modifier)
template <typename value_t, int IoVecDim, typename ptr_t>
inline __device__
void ldg_cg(
value_t (&dest)[IoVecDim],
ptr_t d_in)
{
static_assert(is_pow2<IoVecDim>::value, "I/O vectors must be a power-of-two.");
// Cast dest to a different array type if necessary
ldg_cg_internal(
reinterpret_cast<typename io_cast<value_t, IoVecDim>::type &>(dest),
d_in);
}
/******************************************************************************
* stg_cg intrinsics
******************************************************************************/
/// Store to global (cache-global modifier)
template <typename ptr_t, typename value_t>
inline __device__
void stg_cg(
ptr_t dest,
const value_t &src)
{
// Cast src to a different array type if necessary
stg_cg_internal(
dest,
reinterpret_cast<const typename io_cast<value_t, 1>::type &>(src));
}
/// Store to global (cache-global modifier)
template <typename ptr_t, int IoVecDim, typename value_t>
inline __device__
void stg_cg(
ptr_t dest,
const value_t (&src)[IoVecDim])
{
static_assert(is_pow2<IoVecDim>::value, "I/O vectors must be a power-of-two.");
// Cast src to a different array type if necessary
stg_cg_internal(
dest,
reinterpret_cast<const typename io_cast<value_t, IoVecDim>::type &>(src));
}
} // namespace cutlass

189
cutlass/util/math.h Normal file
View File

@@ -0,0 +1,189 @@
/*
* Copyright 1993-2017 NVIDIA Corporation. All rights reserved.
*
* NOTICE TO LICENSEE:
*
* This source code and/or documentation ("Licensed Deliverables") are
* subject to NVIDIA intellectual property rights under U.S. and
* international Copyright laws.
*
* These Licensed Deliverables contained herein is PROPRIETARY and
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
* conditions of a form of NVIDIA software license agreement by and
* between NVIDIA and Licensee ("License Agreement") or electronically
* accepted by Licensee. Notwithstanding any terms or conditions to
* the contrary in the License Agreement, reproduction or disclosure
* of the Licensed Deliverables to any third party without the express
* written consent of NVIDIA is prohibited.
*
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
* OF THESE LICENSED DELIVERABLES.
*
* U.S. Government End Users. These Licensed Deliverables are a
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
* 1995), consisting of "commercial computer software" and "commercial
* computer software documentation" as such terms are used in 48
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
* U.S. Government End Users acquire the Licensed Deliverables with
* only those rights set forth herein.
*
* Any use of the Licensed Deliverables in individual and commercial
* software must include, in the user documentation and internal
* comments to the code, the above Disclaimer and U.S. Government End
* Users Notice.
*/
#pragma once
/**
* \file
* \brief Math utilities
*/
#include "nv_std.h"
namespace cutlass {
/******************************************************************************
* Static math utilities
******************************************************************************/
/**
* Statically determine if N is a power-of-two
*/
template <int N>
struct is_pow2 : nv_std::integral_constant<bool, (N & (N - 1)) == 0>
{};
/**
* Statically determine log2(N), rounded down
*/
template <int N, int CurrentVal = N, int Count = 0>
struct log2_down
{
/// Static logarithm value
enum { value = log2_down<N, (CurrentVal >> 1), Count + 1>::value };
};
// Base case
template <int N, int Count>
struct log2_down<N, 1, Count>
{
enum { value = Count };
};
/**
* Statically determine log2(N), rounded up
*/
template <int N, int CurrentVal = N, int Count = 0>
struct log2_up
{
/// Static logarithm value
enum { value = log2_up<N, (CurrentVal >> 1), Count + 1>::value };
};
// Base case
template <int N, int Count>
struct log2_up<N, 1, Count>
{
enum { value = ((1 << Count) < N) ? Count + 1 : Count };
};
/**
* Statically estimate sqrt(N) to the nearest power-of-two
*/
template <int N>
struct sqrt_est
{
enum { value = 1 << (log2_up<N>::value / 2) };
};
/**
* For performing a constant-division with a compile-time assertion that the
* Divisor evenly-divides the Dividend.
*/
template <int Dividend, int Divisor>
struct divide_assert
{
enum { value = Dividend / Divisor};
static_assert((Dividend % Divisor == 0), "Not an even multiple");
};
/******************************************************************************
* Rounding
******************************************************************************/
/**
* Round dividend up to the nearest multiple of divisor
*/
template <typename dividend_t, typename divisor_t>
inline __host__ __device__
dividend_t round_nearest(dividend_t dividend, divisor_t divisor)
{
return ((dividend + divisor - 1) / divisor) * divisor;
}
/**
* Greatest common divisor
*/
template <typename value_t>
inline __host__ __device__
value_t gcd(value_t a, value_t b)
{
for (;;)
{
if (a == 0) return b;
b %= a;
if (b == 0) return a;
a %= b;
}
}
/**
* Least common multiple
*/
template <typename value_t>
inline __host__ __device__
value_t lcm(value_t a, value_t b)
{
value_t temp = gcd(a, b);
return temp ? (a / temp * b) : 0;
}
} // namespace cutlass

View File

@@ -0,0 +1,94 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* \brief Enumeration of dense matrix view transformations
*/
#include "printable.h"
namespace cutlass {
/******************************************************************************
* matrix_transform_t
******************************************************************************/
/**
* \brief Enumeration of dense matrix view transformations
*
* These enumerators (and corresponding tag types) describe which view
* transformation needs to be applied prior to operation upon a given dense
* matrix. Its values correspond to Fortran characters 'n' (non-transpose),
* 't'(transpose) and 'c'(conjugate transpose) that are often
* used as parameters to legacy BLAS implementations
*/
struct matrix_transform_t : printable_t
{
/// \brief Enumerants (same as CUBLAS)
enum kind_t
{
/// Invalid view
Invalid = -1,
/// Non-transpose view
NonTranspose = 0,
/// Transpose view
Transpose = 1,
/// Conjugate transpose view
ConjugateTranpose = 2,
};
/// Enumerant value
kind_t kind;
/// Default constructor
matrix_transform_t() : kind(Invalid) {}
/// Copy constructor
matrix_transform_t(const kind_t &other_kind) : kind(other_kind) {}
/// Cast to kind_t
operator kind_t() const { return kind; }
/// Returns the instance as a string
__host__ __device__ inline
char const* to_string() const
{
switch (kind)
{
case NonTranspose: return "NonTranspose";
case Transpose: return "Transpose";
case ConjugateTranpose: return "ConjugateTranpose";
default: return "Invalid";
}
}
/// Insert the formatted instance into the output stream
void print(std::ostream& out) const { out << to_string(); }
};
} // namespace cutlass

727
cutlass/util/nv_std.h Normal file
View File

@@ -0,0 +1,727 @@
/*
* Copyright 1993-2017 NVIDIA Corporation. All rights reserved.
*
* NOTICE TO LICENSEE:
*
* This source code and/or documentation ("Licensed Deliverables") are
* subject to NVIDIA intellectual property rights under U.S. and
* international Copyright laws.
*
* These Licensed Deliverables contained herein is PROPRIETARY and
* CONFIDENTIAL to NVIDIA and is being provided under the terms and
* conditions of a form of NVIDIA software license agreement by and
* between NVIDIA and Licensee ("License Agreement") or electronically
* accepted by Licensee. Notwithstanding any terms or conditions to
* the contrary in the License Agreement, reproduction or disclosure
* of the Licensed Deliverables to any third party without the express
* written consent of NVIDIA is prohibited.
*
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
* LICENSE AGREEMENT, NVIDIA MAKES NO REPRESENTATION ABOUT THE
* SUITABILITY OF THESE LICENSED DELIVERABLES FOR ANY PURPOSE. IT IS
* PROVIDED "AS IS" WITHOUT EXPRESS OR IMPLIED WARRANTY OF ANY KIND.
* NVIDIA DISCLAIMS ALL WARRANTIES WITH REGARD TO THESE LICENSED
* DELIVERABLES, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANTABILITY,
* NONINFRINGEMENT, AND FITNESS FOR A PARTICULAR PURPOSE.
* NOTWITHSTANDING ANY TERMS OR CONDITIONS TO THE CONTRARY IN THE
* LICENSE AGREEMENT, IN NO EVENT SHALL NVIDIA BE LIABLE FOR ANY
* SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL DAMAGES, OR ANY
* DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
* WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS
* ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE
* OF THESE LICENSED DELIVERABLES.
*
* U.S. Government End Users. These Licensed Deliverables are a
* "commercial item" as that term is defined at 48 C.F.R. 2.101 (OCT
* 1995), consisting of "commercial computer software" and "commercial
* computer software documentation" as such terms are used in 48
* C.F.R. 12.212 (SEPT 1995) and is provided to the U.S. Government
* only as a commercial end item. Consistent with 48 C.F.R.12.212 and
* 48 C.F.R. 227.7202-1 through 227.7202-4 (JUNE 1995), all
* U.S. Government End Users acquire the Licensed Deliverables with
* only those rights set forth herein.
*
* Any use of the Licensed Deliverables in individual and commercial
* software must include, in the user documentation and internal
* comments to the code, the above Disclaimer and U.S. Government End
* Users Notice.
*/
#pragma once
/**
* \file
* \brief C++ features that may be otherwise unimplemented for CUDA device functions.
*
* This file has three components:
*
* (1) Macros:
* - Empty macro defines for C++ keywords not supported by the current
* version of C++. These simply allow compilation to proceed (but do
* not provide the added semantics).
* - \p noexcept
* - \p constexpr
* - \p nullptr
* - \p static_assert
*
* - Macro functions that we need in constant expressions because the
* C++ equivalents require constexpr compiler support. These are
* prefixed with \p __NV_STD_*
* - \p __NV_STD_MAX
* - \p __NV_STD_MIN
*
* (2) Re-implementations of STL functions and types:
* - C++ features that need the \p __device__ annotation. These are
* placed into the \p nv_std namespace.
* - \p plus
* - \p less
* - \p greater
* - \p min
* - \p max
* - \p methods on std::pair (==, !=, <, <=, >, >=, and make_pair())
*
* (3) Stop-gap implementations of unsupported STL functions and types:
* - STL functions and types defined by C++ 11/14/17/etc. that are not
* provided by the current version of C++. These are placed into the
* \p nv_std namespace
* - \p integral_constant
* - \p nullptr_t
* - \p true_type
* - \p false_type
* - \p bool_constant
* - \p enable_if
* - \p conditional
* - \p is_same
* - \p is_base_of
* - \p remove_const
* - \p remove_volatile
* - \p remove_cv
* - \p is_volatile
* - \p is_pointer
* - \p is_void
* - \p is_integral
* - \p is_floating_point
* - \p is_arithmetic
* - \p is_fundamental
* - \p is_trivially_copyable
* - \p alignment_of
* - \p aligned_storage
*
* (4) Functions and types that are STL-like (but aren't in the STL):
* - \p TODO: min and max functors?
*
* The idea is that, as we drop support for older compilers, we can simply #define
* the \p __NV_STD_XYZ macros and \p nv_std namespace to alias their C++
* counterparts (or trivially find-and-replace their occurrences in code text).
*/
//-----------------------------------------------------------------------------
// Include STL files that nv_std provides functionality for
//-----------------------------------------------------------------------------
#include <cstddef> // nullptr_t
#include <algorithm> // Minimum/maximum operations
#include <functional> // Arithmetic operations
#include <utility> // For methods on std::pair
#if (!defined(_MSC_VER) && (__cplusplus >= 201103L)) || (defined(_MSC_VER) && (_MS_VER >= 1500))
#include <type_traits> // For integral constants, conditional metaprogramming, and type traits
#endif
/******************************************************************************
* Macros
******************************************************************************/
//-----------------------------------------------------------------------------
// Keywords
//-----------------------------------------------------------------------------
/// noexcept, constexpr
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1900))
#ifndef noexcept
#define noexcept
#endif
#ifndef constexpr
#define constexpr
#endif
#endif
/// nullptr
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1310 ))
#ifndef nullptr
#define nullptr 0
#endif
#endif
/// static_assert
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1600 ))
#ifndef static_assert
#define __nv_std_cat_(a, b) a ## b
#define __nv_std_cat(a, b) __nv_std_cat_(a, b)
#define static_assert(__e, __m) typedef int __nv_std_cat(AsSeRt, __LINE__)[(__e) ? 1 : -1]
#endif
#endif
//-----------------------------------------------------------------------------
// Functions
//-----------------------------------------------------------------------------
/// Select maximum(a, b)
#ifndef __NV_STD_MAX
#define __NV_STD_MAX(a, b) (((b) > (a)) ? (b) : (a))
#endif
/// Select minimum(a, b)
#ifndef __NV_STD_MIN
#define __NV_STD_MIN(a, b) (((b) < (a)) ? (b) : (a))
#endif
/******************************************************************************
* Re-implementations
******************************************************************************/
namespace nv_std {
//-----------------------------------------------------------------------------
// Arithmetic operations, comparisons <functional>
//-----------------------------------------------------------------------------
/// nv_std::plus
template <typename T>
struct plus
{
inline __host__ __device__
constexpr T operator()(const T &lhs, const T &rhs) const
{
return lhs + rhs;
}
};
/// std::less
template <typename T>
struct less
{
inline __host__ __device__
constexpr bool operator()(const T &lhs, const T &rhs) const
{
return lhs < rhs;
}
};
/// std::greater
template <typename T>
struct greater
{
inline __host__ __device__
constexpr bool operator()(const T &lhs, const T &rhs) const
{
return lhs > rhs;
}
};
//-----------------------------------------------------------------------------
// Minimum/maximum operations <algorithm>
//-----------------------------------------------------------------------------
/// std::min
template <typename T>
inline __host__ __device__
constexpr const T& min(
const T& a,
const T& b)
{
return (b < a) ? b : a;
}
/// std::max
template <typename T>
inline __host__ __device__
constexpr const T& max(
const T& a,
const T& b)
{
return (a < b) ? b : a;
}
//-----------------------------------------------------------------------------
// Methods on std::pair
//-----------------------------------------------------------------------------
using std::pair;
template< class T1, class T2 >
inline __host__ __device__
constexpr bool operator==( const pair<T1,T2>& lhs, const pair<T1,T2>& rhs )
{
return (lhs.first == rhs.first) && (lhs.second == rhs.second);
}
template< class T1, class T2 >
inline __host__ __device__
constexpr bool operator!=( const pair<T1,T2>& lhs, const pair<T1,T2>& rhs )
{
return (lhs.first != rhs.first) && (lhs.second != rhs.second);
}
template< class T1, class T2 >
inline __host__ __device__
constexpr bool operator<( const pair<T1,T2>& lhs, const pair<T1,T2>& rhs )
{
return (lhs.first < rhs.first) ?
true :
(rhs.first < lhs.first) ?
false :
(lhs.second < rhs.second);
}
template< class T1, class T2 >
inline __host__ __device__
constexpr bool operator<=( const pair<T1,T2>& lhs, const pair<T1,T2>& rhs )
{
return !(rhs < lhs);
}
template< class T1, class T2 >
inline __host__ __device__
constexpr bool operator>( const pair<T1,T2>& lhs, const pair<T1,T2>& rhs )
{
return (rhs < lhs);
}
template< class T1, class T2 >
inline __host__ __device__
constexpr bool operator>=( const pair<T1,T2>& lhs, const pair<T1,T2>& rhs )
{
return !(lhs < rhs);
}
template< class T1, class T2 >
inline __host__ __device__
std::pair<T1,T2> make_pair( T1 t, T2 u )
{
std::pair<T1,T2> retval;
retval.first = t;
retval.second = u;
return retval;
}
} // namespace nv_std
/******************************************************************************
* Implementations of C++ 11/14/17/... STL features
******************************************************************************/
namespace nv_std {
//-----------------------------------------------------------------------------
// Integral constant helper types <type_traits>
//-----------------------------------------------------------------------------
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500))
/// std::integral_constant
template <typename value_t, value_t V>
struct integral_constant;
/// std::integral_constant
template <typename value_t, value_t V>
struct integral_constant
{
static const value_t value = V;
typedef value_t value_type;
typedef integral_constant<value_t, V> type;
inline __host__ __device__ operator value_type() const
{
return value;
}
inline __host__ __device__ const value_type operator()() const
{
return value;
}
};
#else
using std::integral_constant;
using std::pair;
#endif
/// The type used as a compile-time boolean with true value.
typedef integral_constant<bool, true> true_type;
/// The type used as a compile-time boolean with false value.
typedef integral_constant<bool, false> false_type;
#if (!defined(_MSC_VER) && (__cplusplus < 201402L)) || (defined(_MSC_VER) && (_MSC_VER < 1900))
/// std::bool_constant
template <bool V>
struct bool_constant : nv_std::integral_constant<bool, V>
{};
#else
using std::bool_constant;
#endif
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1700))
/// std::nullptr_t
struct nullptr_t {};
#else
using std::nullptr_t;
#endif
//-----------------------------------------------------------------------------
// Conditional metaprogramming <type_traits>
//-----------------------------------------------------------------------------
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1600))
/// std::enable_if (true specialization)
template<bool C, typename T = void>
struct enable_if {
typedef T type;
};
/// std::enable_if (false specialization)
template<typename T>
struct enable_if<false, T> { };
/// std::conditional (true specialization)
template<bool B, class T, class F>
struct conditional { typedef T type; };
/// std::conditional (false specialization)
template<class T, class F>
struct conditional<false, T, F> { typedef F type; };
#else
using std::enable_if;
using std::conditional;
#endif
//-----------------------------------------------------------------------------
// Const/volatility specifiers <type_traits>
//-----------------------------------------------------------------------------
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500))
/// std::remove_const (non-const specialization)
template <typename T> struct remove_const { typedef T type; };
/// std::remove_const (const specialization)
template <typename T> struct remove_const<const T> { typedef T type; };
/// std::remove_volatile (non-volatile specialization)
template <typename T> struct remove_volatile { typedef T type; };
/// std::remove_volatile (volatile specialization)
template <typename T> struct remove_volatile<volatile T> { typedef T type; };
/// std::remove_cv
template <typename T>
struct remove_cv {
typedef typename remove_volatile<typename remove_const<T>::type>::type type;
};
#else
using std::remove_const;
using std::remove_volatile;
using std::remove_cv;
#endif
//-----------------------------------------------------------------------------
// Type relationships <type_traits>
//-----------------------------------------------------------------------------
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500))
/// std::is_same (false specialization)
template <typename A, typename B>
struct is_same : false_type
{};
/// std::is_same (true specialization)
template <typename A>
struct is_same<A, A> : true_type
{};
/// Helper for std::is_base_of
template<typename BaseT, typename DerivedT>
struct is_base_of_helper
{
typedef char (&yes)[1];
typedef char (&no)[2];
template<typename B, typename D>
struct dummy
{
operator B*() const;
operator D*();
};
template<typename T>
static yes check(DerivedT*, T);
static no check(BaseT*, int);
static const bool value = sizeof(check(dummy<BaseT, DerivedT>(), int())) == sizeof(yes);
};
/// std::is_base_of
template <typename BaseT, typename DerivedT>
struct is_base_of : integral_constant<
bool,
(is_base_of_helper<typename remove_cv<BaseT>::type, typename remove_cv<DerivedT>::type>::value) ||
(is_same<typename remove_cv<BaseT>::type, typename remove_cv<DerivedT>::type>::value)>
{};
#else
using std::is_same;
using std::is_base_of;
#endif
//-----------------------------------------------------------------------------
// Type properties <type_traits>
//-----------------------------------------------------------------------------
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500))
/// std::is_volatile
template <typename T> struct is_volatile : false_type {};
template <typename T> struct is_volatile<volatile T> : true_type {};
/// Helper for std::is_pointer (false specialization)
template <typename T> struct is_pointer_helper : false_type {};
/// Helper for std::is_pointer (true specialization)
template <typename T> struct is_pointer_helper<T*> : true_type {};
/// std::is_pointer
template <typename T> struct is_pointer : is_pointer_helper<typename remove_cv<T>::type> {};
/// std::is_void
template <typename T>
struct is_void : is_same<void, typename remove_cv<T>::type>
{};
/// std::is_integral
template <typename T> struct is_integral : false_type {};
template <> struct is_integral<char> : true_type {};
template <> struct is_integral<signed char> : true_type {};
template <> struct is_integral<unsigned char> : true_type {};
template <> struct is_integral<short> : true_type {};
template <> struct is_integral<unsigned short> : true_type {};
template <> struct is_integral<int> : true_type {};
template <> struct is_integral<unsigned int> : true_type {};
template <> struct is_integral<long> : true_type {};
template <> struct is_integral<unsigned long> : true_type {};
template <> struct is_integral<long long> : true_type {};
template <> struct is_integral<unsigned long long> : true_type {};
template <typename T> struct is_integral<volatile T> : is_integral<T> {};
template <typename T> struct is_integral<const T> : is_integral<T> {};
template <typename T> struct is_integral<const volatile T> : is_integral<T> {};
/// std::is_floating_point
template <typename T>
struct is_floating_point : integral_constant<
bool,
(is_same<float, typename remove_cv<T>::type>::value ||
is_same<double, typename remove_cv<T>::type>::value)>
{};
/// std::is_arithmetic
template <typename T>
struct is_arithmetic :
integral_constant<bool, (is_integral<T>::value || is_floating_point<T>::value)>
{};
/// std::is_fundamental
template <typename T>
struct is_fundamental : integral_constant<
bool, (is_arithmetic<T>::value ||
is_void<T>::value ||
is_same<nullptr_t, typename remove_cv<T>::type>::value)>
{};
#else
using std::is_volatile;
using std::is_pointer;
using std::is_void;
using std::is_integral;
using std::is_floating_point;
using std::is_arithmetic;
using std::is_fundamental;
#endif
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || \
(defined(_MSC_VER) && (_MSC_VER < 1800)) || \
(defined(__GNUG__) && (__GNUC__ < 5))
/**
* std::is_trivially_copyable
*
* This implementation only evaluates true if T is fundamental or pointer
*
* Without help from partial template specializations provided by the user for
* a specific class or struct, this trait will never report that the specified
* class or struct is trivially-copyable ; this is always safe,
* if possibly sub-optimal.
*/
template <typename T>
struct is_trivially_copyable :
integral_constant<bool, (is_fundamental<T>::value || is_pointer<T>::value)>
{};
#else
using std::is_trivially_copyable;
#endif
//-----------------------------------------------------------------------------
// Alignment and layout utilities
//-----------------------------------------------------------------------------
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1500))
/// std::alignment_of
template <typename value_t>
struct alignment_of
{
struct pad
{
value_t val;
char byte;
};
enum
{
value = sizeof(pad) - sizeof(value_t)
};
};
#else
template <typename value_t>
struct alignment_of : std::alignment_of<value_t> {};
#endif
/* 16B specializations where 32-bit Win32 host compiler disagrees with device compiler */
template <> struct alignment_of<int4> { enum { value = 16 }; };
template <> struct alignment_of<uint4> { enum { value = 16 }; };
template <> struct alignment_of<float4> { enum { value = 16 }; };
template <> struct alignment_of<long4> { enum { value = 16 }; };
template <> struct alignment_of<ulong4> { enum { value = 16 }; };
template <> struct alignment_of<longlong2> { enum { value = 16 }; };
template <> struct alignment_of<ulonglong2> { enum { value = 16 }; };
template <> struct alignment_of<double2> { enum { value = 16 }; };
template <> struct alignment_of<longlong4> { enum { value = 16 }; };
template <> struct alignment_of<ulonglong4> { enum { value = 16 }; };
template <> struct alignment_of<double4> { enum { value = 16 }; };
// Specializations for volatile/const qualified types
template <typename value_t> struct alignment_of<volatile value_t> : alignment_of<value_t> {};
template <typename value_t> struct alignment_of<const value_t> : alignment_of<value_t> {};
template <typename value_t> struct alignment_of<const volatile value_t> : alignment_of<value_t> {};
#if (!defined(_MSC_VER) && (__cplusplus < 201103L)) || (defined(_MSC_VER) && (_MSC_VER < 1800))
template<size_t Align> struct aligned_chunk;
template<> struct __align__(1) aligned_chunk<1> { uint8_t buff; };
template<> struct __align__(2) aligned_chunk<2> { uint16_t buff; };
template<> struct __align__(4) aligned_chunk<4> { uint32_t buff; };
template<> struct __align__(8) aligned_chunk<8> { uint32_t buff[2]; };
template<> struct __align__(16) aligned_chunk<16> { uint32_t buff[4]; };
template<> struct __align__(32) aligned_chunk<32> { uint32_t buff[8]; };
template<> struct __align__(64) aligned_chunk<64> { uint32_t buff[16]; };
template<> struct __align__(128) aligned_chunk<128> { uint32_t buff[32]; };
template<> struct __align__(256) aligned_chunk<256> { uint32_t buff[64]; };
template<> struct __align__(512) aligned_chunk<512> { uint32_t buff[128]; };
template<> struct __align__(1024) aligned_chunk<1024> { uint32_t buff[256]; };
template<> struct __align__(2048) aligned_chunk<2048> { uint32_t buff[512]; };
template<> struct __align__(4096) aligned_chunk<4096> { uint32_t buff[1024]; };
/// std::aligned_storage
template <size_t Len, size_t Align>
struct aligned_storage
{
typedef aligned_chunk<Align> type[Len / sizeof(aligned_chunk<Align>)];
};
#else
using std::aligned_storage;
#endif
}; // namespace nv_std

64
cutlass/util/printable.h Normal file
View File

@@ -0,0 +1,64 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* \brief Pure virtual base class for printable types
*/
#include <iostream>
namespace cutlass {
/******************************************************************************
* printable_t
******************************************************************************/
/**
* Pure virtual base class for printable types
*/
struct printable_t
{
/// Returns the instance as a string
__host__ __device__ inline
virtual char const* to_string() const = 0;
/// Insert the formatted instance into the output stream
virtual void print(std::ostream& out) const = 0;
/// Destructor
virtual ~printable_t() {}
};
/// Insert the formatted \p printable into the output stream
std::ostream& operator<<(
std::ostream& out,
printable_t const& printable)
{
printable.print(out);
return out;
}
} // namespace cutlass

74
cutlass/util/util.h Normal file
View File

@@ -0,0 +1,74 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* \brief Umbrella header file for utilities
*/
#include "debug.h"
#include "device_introspection.h"
#include "io_intrinsics.h"
#include "math.h"
#include "nv_std.h"
#include "printable.h"
#include "matrix_transform.h"
namespace cutlass {
/******************************************************************************
* int_constant
******************************************************************************/
/**
* Shorthand for nv_std::integral_constant of int32_t type
*/
template <int V>
struct int_constant : nv_std::integral_constant<int32_t, V>
{};
/******************************************************************************
* Uninitialized
******************************************************************************/
/**
* \brief A storage-backing wrapper that allows types with non-trivial constructors to be aliased in unions
*/
template <typename T>
struct __align__(16) uninitialized
{
/// Backing storage
uint8_t storage[sizeof(T)];
/// Alias
__host__ __device__ __forceinline__ T& alias()
{
return reinterpret_cast<T&>(*this);
}
};
} // namespace cutlass

7
cutlass_test/.gitignore vendored Normal file
View File

@@ -0,0 +1,7 @@
/bin/
/gemm-GPU.csv
/gemm-REF.csv
/a.csv
/b.csv
/gp100_schmoo/
/ignore/

213
cutlass_test/Makefile Normal file
View File

@@ -0,0 +1,213 @@
#/******************************************************************************
# * Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
# *
# * Redistribution and use in source and binary forms, with or without
# * modification, are not permitted.
# *
# * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
# * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
# * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# *
#******************************************************************************/
#-------------------------------------------------------------------------------
#
# Makefile usage
#
# make <target> sm=<XX[,YY,ZZ,..]> [transpose=<nn*|nt|tn|tt>] [verbose=<0*|1>] [keep=<0*|1>]
#
# * : default
#
#-------------------------------------------------------------------------------
TEST_DIR := $(dir $(lastword $(MAKEFILE_LIST)))
include ../common.mk
#-------------------------------------------------------------------------------
# Commandline Options
#-------------------------------------------------------------------------------
ifdef transpose
TRANSPOSE := $(transpose)
else
TRANSPOSE := nn
endif
ifdef deepbench
BENCHMARK_DEEPBENCH := $(deepbench)
else
BENCHMARK_DEEPBENCH := 0
endif
# If defined, GEMMs only compiled with specified alignment restrictions on A and B
# matrices. Otherwise, kernels are compiled for all feasible alignment options, and
# the appropriate kernel is selected.
ifdef alignment
DEFINES += -DGEMM_ALIGNMENT=$(alignment)
endif
# If defined as false, ragged handling can be disabled.
ifdef ragged
DEFINES += -DGEMM_RAGGED=$(ragged)
endif
#-------------------------------------------------------------------------------
# Include and Library paths
#-------------------------------------------------------------------------------
INC += -I$(TEST_DIR)
INC += -I$(BASE_DIR)
LIBS += -lcublas
#-------------------------------------------------------------------------------
# Preprocessor definitions
#-------------------------------------------------------------------------------
ifeq (nt, $(TRANSPOSE))
DEFINES += -DTRANSPOSE_B
else ifeq (tn, $(TRANSPOSE))
DEFINES += -DTRANSPOSE_A
else ifeq (tt, $(TRANSPOSE))
DEFINES += -DTRANSPOSE_A
DEFINES += -DTRANSPOSE_B
endif
NVCCFLAGS += -std=c++11
#-------------------------------------------------------------------------------
# Dependency Lists
#-------------------------------------------------------------------------------
DEPS := $(call rwildcard, $(BASE_DIR),*.h) \
$(call rwildcard, $(BASE_DIR)cgl,*.h) \
$(BASE_DIR)common.mk \
$(TEST_DIR)Makefile
ALL := sgemm \
dgemm \
hgemm \
igemm
#-------------------------------------------------------------------------------
# make default
#-------------------------------------------------------------------------------
default:
#-------------------------------------------------------------------------------
# make clean
#-------------------------------------------------------------------------------
clean :
rm -f bin/*
rm -f *.i* *.cubin *.cu.c *.cudafe* *.fatbin.c *.ptx *.hash *.cu.cpp *.o *.obj* *dlink.* *.res *.fatbin *.module_id
#-------------------------------------------------------------------------------
# make all
#-------------------------------------------------------------------------------
all : $(ALL)
#-------------------------------------------------------------------------------
# make sgemm
#-------------------------------------------------------------------------------
sgemm: bin/sgemm_$(TRANSPOSE)_$(BIN_SUFFIX)
sgemm_testbench: bin/sgemm_testbench_$(BIN_SUFFIX)
bin/sgemm_$(TRANSPOSE)_$(BIN_SUFFIX) : gemm.cu $(DEPS)
mkdir -p bin
$(NVCC) -DTEST_SGEMM $(DEFINES) $(SM_TARGETS) -o $@ gemm.cu $(NVCCFLAGS) $(CPU_ARCH) $(INC) $(LIBINC) $(LIBS)
bin/sgemm_testbench_$(BIN_SUFFIX) : gemm_testbench.cu $(DEPS)
mkdir -p bin
$(NVCC) -DTEST_SGEMM $(DEFINES) $(SM_TARGETS) -D BENCHMARK_DEEPBENCH=$(BENCHMARK_DEEPBENCH) -o $@ gemm_testbench.cu $(NVCCFLAGS) $(CPU_ARCH) $(INC) $(LIBINC) $(LIBS)
#-------------------------------------------------------------------------------
# make dgemm
#-------------------------------------------------------------------------------
dgemm: bin/dgemm_$(TRANSPOSE)_$(BIN_SUFFIX)
dgemm_testbench: bin/dgemm_testbench_$(BIN_SUFFIX)
bin/dgemm_$(TRANSPOSE)_$(BIN_SUFFIX) : gemm.cu $(DEPS)
mkdir -p bin
$(NVCC) -DTEST_DGEMM $(DEFINES) $(SM_TARGETS) -o $@ gemm.cu $(NVCCFLAGS) $(CPU_ARCH) $(INC) $(LIBINC) $(LIBS)
bin/dgemm_testbench_$(BIN_SUFFIX) : gemm_testbench.cu $(DEPS)
mkdir -p bin
$(NVCC) -DTEST_DGEMM $(DEFINES) $(SM_TARGETS) -D BENCHMARK_DEEPBENCH=$(BENCHMARK_DEEPBENCH) -o $@ gemm_testbench.cu $(NVCCFLAGS) $(CPU_ARCH) $(INC) $(LIBINC) $(LIBS)
#-------------------------------------------------------------------------------
# make hgemm
#-------------------------------------------------------------------------------
hgemm: bin/hgemm_$(TRANSPOSE)_$(BIN_SUFFIX)
hgemm_testbench: bin/hgemm_testbench_$(BIN_SUFFIX)
bin/hgemm_$(TRANSPOSE)_$(BIN_SUFFIX) : gemm.cu $(DEPS)
mkdir -p bin
$(NVCC) -DTEST_HGEMM $(DEFINES) $(SM_TARGETS) -o $@ gemm.cu $(NVCCFLAGS) $(CPU_ARCH) $(INC) $(LIBINC) $(LIBS)
bin/hgemm_testbench_$(BIN_SUFFIX) : gemm_testbench.cu $(DEPS)
mkdir -p bin
$(NVCC) -DTEST_HGEMM $(DEFINES) $(SM_TARGETS) -D BENCHMARK_DEEPBENCH=$(BENCHMARK_DEEPBENCH) -o $@ gemm_testbench.cu $(NVCCFLAGS) $(CPU_ARCH) $(INC) $(LIBINC) $(LIBS)
#-------------------------------------------------------------------------------
# make igemm
#-------------------------------------------------------------------------------
igemm: bin/igemm_$(TRANSPOSE)_$(BIN_SUFFIX)
bin/igemm_$(TRANSPOSE)_$(BIN_SUFFIX) : gemm.cu $(DEPS)
mkdir -p bin
$(NVCC) -DTEST_IGEMM $(DEFINES) $(SM_TARGETS) -o $@ gemm.cu $(NVCCFLAGS) $(CPU_ARCH) $(INC) $(LIBINC) $(LIBS)
igemm_testbench: bin/igemm_testbench_$(BIN_SUFFIX)
bin/igemm_testbench_$(BIN_SUFFIX) : gemm_testbench.cu $(DEPS)
mkdir -p bin
$(NVCC) -DTEST_IGEMM $(DEFINES) $(SM_TARGETS) -D BENCHMARK_DEEPBENCH=$(BENCHMARK_DEEPBENCH) -o $@ gemm_testbench.cu $(NVCCFLAGS) $(CPU_ARCH) $(INC) $(LIBINC) $(LIBS)
#-------------------------------------------------------------------------------
# make wgemm
#-------------------------------------------------------------------------------
wgemm: bin/wgemm_$(TRANSPOSE)_$(BIN_SUFFIX)
wgemm_testbench: bin/wgemm_testbench_$(BIN_SUFFIX)
bin/wgemm_$(TRANSPOSE)_$(BIN_SUFFIX) : gemm.cu $(DEPS)
mkdir -p bin
$(NVCC) -DTEST_WGEMM -DWMMA $(DEFINES) $(SM_TARGETS) -o $@ gemm.cu $(NVCCFLAGS) $(CPU_ARCH) $(INC) $(LIBINC) $(LIBS)
bin/wgemm_testbench_$(BIN_SUFFIX) : gemm_testbench.cu $(DEPS)
mkdir -p bin
$(NVCC) -DTEST_WGEMM -DWMMA $(DEFINES) $(SM_TARGETS) -D BENCHMARK_DEEPBENCH=$(BENCHMARK_DEEPBENCH) -o $@ gemm_testbench.cu $(NVCCFLAGS) $(CPU_ARCH) $(INC) $(LIBINC) $(LIBS)

View File

@@ -0,0 +1,292 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* C++ interface for dispatching CUBLAS GEMM calls
*/
#include <cublas_v2.h>
namespace cutlass {
/******************************************************************************
* cuBLAS dispatch entrypoints
******************************************************************************/
/**
* Dispatch cuBLAS igemm
*/
cublasStatus_t cublas_gemm_dispatch(
cublasHandle_t cublas_handle, ///< CUBLAS handle
cublasOperation_t transform_a, ///< Transform op(A) that is non- or (conj.) transpose.
cublasOperation_t transform_b, ///< Transform op(B) that is non- or (conj.) transpose.
int m, ///< Height in rows of op(A) and C
int n, ///< Width in columns of op(B) and C
int k, ///< Width in columns of op(A) and height in rows of op(B)
int32_t alpha, ///< Scalar used for multiplicands
int8_t *d_a, ///< Device pointer to matrix A array values
int8_t *d_b, ///< Device pointer to matrix B array values
int32_t beta, ///< Scalar used for addend
int32_t *d_c, ///< Device pointer to matrix C array values
cudaStream_t stream = 0, ///< CUDA stream to launch kernels within. Default is stream<sub>0</sub>.
bool debug_synchronous = false) ///< Whether or not to synchronize the stream after every kernel launch to check for errors.
{
return cublasGemmEx(
cublas_handle,
transform_a,
transform_b,
m,
n,
k,
(void*) &alpha,
(void*) d_a,
CUDA_R_8I,
(transform_a == CUBLAS_OP_N) ? m : k,
(void*) d_b,
CUDA_R_8I,
(transform_b == CUBLAS_OP_N) ? k : n,
(void*) &beta,
(void*) d_c,
CUDA_R_32I,
m,
CUDA_R_32I,
CUBLAS_GEMM_DFALT);
}
/**
* Dispatch cuBLAS hgemm
*/
cublasStatus_t cublas_gemm_dispatch(
cublasHandle_t cublas_handle, ///< CUBLAS handle
cublasOperation_t transform_a, ///< Transform op(A) that is non- or (conj.) transpose.
cublasOperation_t transform_b, ///< Transform op(B) that is non- or (conj.) transpose.
int m, ///< Height in rows of op(A) and C
int n, ///< Width in columns of op(B) and C
int k, ///< Width in columns of op(A) and height in rows of op(B)
__half alpha, ///< Scalar used for multiplicands
__half *d_a, ///< Device pointer to matrix A array values
__half *d_b, ///< Device pointer to matrix B array values
__half beta, ///< Scalar used for addend
__half *d_c, ///< Device pointer to matrix C array values
cudaStream_t stream = 0, ///< CUDA stream to launch kernels within. Default is stream<sub>0</sub>.
bool debug_synchronous = false) ///< Whether or not to synchronize the stream after every kernel launch to check for errors.
{
return cublasHgemm(
cublas_handle, transform_a, transform_b,
m, n, k,
&alpha,
d_a,
(transform_a == CUBLAS_OP_N) ? m : k,
d_b,
(transform_b == CUBLAS_OP_N) ? k : n,
&beta,
d_c,
m);
}
/**
* Dispatch cuBLAS sgemm
*/
cublasStatus_t cublas_gemm_dispatch(
cublasHandle_t cublas_handle, ///< CUBLAS handle
cublasOperation_t transform_a, ///< Transform op(A) that is non- or (conj.) transpose.
cublasOperation_t transform_b, ///< Transform op(B) that is non- or (conj.) transpose.
int m, ///< Height in rows of op(A) and C
int n, ///< Width in columns of op(B) and C
int k, ///< Width in columns of op(A) and height in rows of op(B)
float alpha, ///< Scalar used for multiplicands
float *d_a, ///< Device pointer to matrix A array values
float *d_b, ///< Device pointer to matrix B array values
float beta, ///< Scalar used for addend
float *d_c, ///< Device pointer to matrix C array values
cudaStream_t stream = 0, ///< CUDA stream to launch kernels within. Default is stream<sub>0</sub>.
bool debug_synchronous = false) ///< Whether or not to synchronize the stream after every kernel launch to check for errors.
{
return cublasSgemm(
cublas_handle, transform_a, transform_b,
m, n, k,
&alpha,
d_a,
(transform_a == CUBLAS_OP_N) ? m : k,
d_b,
(transform_b == CUBLAS_OP_N) ? k : n,
&beta,
d_c,
m);
}
/**
* Dispatch cuBLAS dgemm
*/
cublasStatus_t cublas_gemm_dispatch(
cublasHandle_t cublas_handle, ///< CUBLAS handle
cublasOperation_t transform_a, ///< Transform op(A) that is non- or (conj.) transpose.
cublasOperation_t transform_b, ///< Transform op(B) that is non- or (conj.) transpose.
int m, ///< Height in rows of op(A) and C
int n, ///< Width in columns of op(B) and C
int k, ///< Width in columns of op(A) and height in rows of op(B)
double alpha, ///< Scalar used for multiplicands
double *d_a, ///< Device pointer to matrix A array values
double *d_b, ///< Device pointer to matrix B array values
double beta, ///< Scalar used for addend
double *d_c, ///< Device pointer to matrix C array values
cudaStream_t stream = 0, ///< CUDA stream to launch kernels within. Default is stream<sub>0</sub>.
bool debug_synchronous = false) ///< Whether or not to synchronize the stream after every kernel launch to check for errors.
{
return cublasDgemm(
cublas_handle, transform_a, transform_b,
m, n, k,
&alpha,
d_a, (transform_a == CUBLAS_OP_N) ? m : k,
d_b, (transform_b == CUBLAS_OP_N) ? k : n,
&beta,
d_c, m);
}
/**
* Dispatch cuBLAS Tensor Cores GEMM
*/
cublasStatus_t cublas_gemm_dispatch(
cublasHandle_t cublas_handle, ///< CUBLAS handle
cublasOperation_t transform_a, ///< Transform op(A) that is non- or (conj.) transpose.
cublasOperation_t transform_b, ///< Transform op(B) that is non- or (conj.) transpose.
int m, ///< Height in rows of op(A) and C
int n, ///< Width in columns of op(B) and C
int k, ///< Width in columns of op(A) and height in rows of op(B)
float alpha, ///< Scalar used for multiplicands
half *d_a, ///< Device pointer to matrix A array values
half *d_b, ///< Device pointer to matrix B array values
float beta, ///< Scalar used for addend
float *d_c, ///< Device pointer to matrix C array values
cudaStream_t stream = 0, ///< CUDA stream to launch kernels within. Default is stream<sub>0</sub>.
bool debug_synchronous = false) ///< Whether or not to synchronize the stream after every kernel launch to check for errors.
{
return cublasGemmEx(
cublas_handle,
transform_a,
transform_b,
m,
n,
k,
(void*) &alpha,
(void*) d_a,
CUDA_R_16F,
(transform_a == CUBLAS_OP_N) ? m : k,
(void*) d_b,
CUDA_R_16F,
(transform_b == CUBLAS_OP_N) ? k : n,
(void*) &beta,
(void*) d_c,
CUDA_R_32F,
m,
CUDA_R_32F,
CUBLAS_GEMM_DFALT_TENSOR_OP);
}
/**
* Uses cuBLAS to compute gemm on device matrices (unspecialized)
*/
template <
gemm::tiling_strategy::kind_t _TilingStrategy, ///< Tile-sizing classification category
math_operation_class_t _math_op,
matrix_transform_t::kind_t _TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t _TransformB, ///< Transformation op for matrix B
typename _value, ///< Multiplicand value type (matrices A and B)
typename _accum ///< Accumulator value type (matrix C and scalars)
>
struct cublas_gemm
{
//
// Type alias definitions
//
static const gemm::tiling_strategy::kind_t TilingStrategy = _TilingStrategy;
static const math_operation_class_t math_op = _math_op;
static const matrix_transform_t::kind_t TransformA = _TransformA;
static const matrix_transform_t::kind_t TransformB = _TransformB;
using value_t = _value;
using accum_t = _accum;
/// Launches a GEMM
gemm::launch_configuration operator()(
cublasHandle_t cublas_handle, ///< CUBLAS handle
int m,
int n,
int k,
value_t *A, ///< A matrix
value_t *B, ///< B matrix
accum_t *C, ///< C matrix
accum_t alpha, ///< Scalar used for multiplicands
accum_t beta, ///< Scalar used for addend
cudaStream_t stream = 0, ///< CUDA stream to launch kernels within. Default is stream<sub>0</sub>.
bool debug_synchronous = false) ///< Whether or not to synchronize the stream after every kernel launch to check for errors.
{
cublasStatus_t cublas_error = cublas_gemm_dispatch(
cublas_handle,
(cublasOperation_t) TransformA,
(cublasOperation_t) TransformB,
m,
n,
k,
alpha,
A,
B,
beta,
C,
stream,
debug_synchronous);
cudaError_t error;
if (cublas_error != CUBLAS_STATUS_SUCCESS)
{
if (cublas_error == CUBLAS_STATUS_NOT_SUPPORTED) {
return gemm::launch_configuration(cudaErrorInvalidValue);
}
error = cudaGetLastError();
if (error == cudaSuccess) {
return gemm::launch_configuration(cudaErrorUnknown);
}
return error;
}
// Check for failure to launch
if (CUDA_PERROR_DEBUG(error = cudaPeekAtLastError()))
return gemm::launch_configuration(error);
// Sync the stream if specified to flush runtime errors
if (debug_synchronous && (CUDA_PERROR_DEBUG(error = cudaStreamSynchronize(stream))))
return gemm::launch_configuration(error);
return gemm::launch_configuration(error);
}
};
} // namespace cutlass

View File

@@ -0,0 +1,253 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file Dispatch routines for CUTLASS GEMM kernels
*/
// CUDA includes
#include <cublas_v2.h>
// Cutlass GEMM API
#include <cutlass/util/util.h>
#include <cutlass/gemm/dispatch.h>
#include <cutlass/gemm/epilogue_function.h>
// Test utilities
#include "util/type_conversion.h"
namespace cutlass {
/******************************************************************************
* Cutlass dispatch entrypoints
******************************************************************************/
//
// Compile-time overrides for alignment and ragged handling.
//
// If zero, all feasible alignment options are supported.
#ifndef GEMM_ALIGNMENT
#define GEMM_ALIGNMENT 0
#endif
// If true, kernels are compiled with ragged handling enabled.
#ifndef GEMM_RAGGED
#define GEMM_RAGGED true
#endif
//
// Dispatch logic given problem size specialization, math operation class, layout
// and type of operands, and epilogue operation.
//
/**
* Cutlass GEMM dispatch
*/
template <
gemm::tiling_strategy::kind_t _TilingStrategy, ///< Tile-sizing classification category
math_operation_class_t _math_op, // Indicates
matrix_transform_t::kind_t _TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t _TransformB, ///< Transformation op for matrix B
typename _value, ///< Multiplicand value type (matrices A and B)
typename _accum, ///< Accumulator value type (matrix C and scalars)
typename _epilogue_op_t ///< Epilogue opeartion to update matrix C
= gemm::blas_scaled_epilogue<_accum, _accum, _accum>
>
struct cutlass_gemm_dispatch
{
//
// Type alias definitions
//
static const gemm::tiling_strategy::kind_t TilingStrategy = _TilingStrategy;
static const math_operation_class_t math_op = _math_op;
static const matrix_transform_t::kind_t TransformA = _TransformA;
static const matrix_transform_t::kind_t TransformB = _TransformB;
using value_t = _value;
using accum_t = _accum;
using epilogue_op_t = _epilogue_op_t;
//
// Methods
//
/// Returns leading dimension for A matrix operand
int leading_dim_a(int m, int k) const
{
return (TransformA == matrix_transform_t::NonTranspose ? m : k);
}
/// Returns leading dimension for B matrix operand
int leading_dim_b(int k, int n) const
{
return (TransformB == matrix_transform_t::NonTranspose ? k : n);
}
/// Launches a GEMM
template <int operand_alignment, int accumulator_alignment>
gemm::launch_configuration launch(
int m,
int n,
int k,
epilogue_op_t epilogue_op,
value_t *A,
value_t *B,
accum_t *C,
cudaStream_t stream = 0,
bool debug_synchronous = false)
{
return gemm::device_gemm<
TilingStrategy,
math_op,
TransformA,
operand_alignment,
TransformB,
operand_alignment,
value_t,
accum_t,
epilogue_op_t,
accumulator_alignment>
(
m,
n,
k,
epilogue_op,
A,
B,
C,
stream,
debug_synchronous);
}
/// Dispatches a CUTLASS GEMM
gemm::launch_configuration operator()(
cublasHandle_t handle, ///< CUBLAS handle
int m, ///< Rows of GEMM problem
int n, ///< Columns of GEMM problem
int k, ///< Inner dimension of GEMM problem
value_t *A, ///< A matrix
value_t *B, ///< B matrix
accum_t *C, ///< C matrix
accum_t alpha, ///< Scalar used for multiplicands
accum_t beta, ///< Scalar used for addend
cudaStream_t stream = 0, ///< CUDA stream to launch kernels within.
bool debug_synchronous = false) ///< Whether or not to synchronize the stream
/// after every kernel launch to check for errors.
{
// Forces kernel selection to choose specific alignment (in bytes)
int const force_operand_alignment = GEMM_ALIGNMENT;
// Problem size must be multiple of the smallest vector load size
typedef value_t operand_load_t;
int const accumulator_alignment = sizeof(accum_t);
int const lda = leading_dim_a(m, k);
int const ldb = leading_dim_b(k, n);
epilogue_op_t epilogue(alpha, beta);
// TODO: opportunity for metaprogramming loop
// Prefer the largest granularity of vector load that is compatible with
// problem size and data alignment.
if ((!force_operand_alignment || force_operand_alignment == 16) &&
!((sizeof(operand_load_t) * lda) % 16) &&
!((sizeof(operand_load_t) * ldb) % 16))
{
#if !(GEMM_ALIGNMENT) || (GEMM_ALIGNMENT == 16)
return launch<__NV_STD_MAX(16, sizeof(value_t)), accumulator_alignment>(
m,
n,
k,
epilogue,
A,
B,
C,
stream,
debug_synchronous);
#endif
}
else if ((!force_operand_alignment || force_operand_alignment == 8) &&
!((sizeof(operand_load_t) * lda) % 8) &&
!((sizeof(operand_load_t) * ldb) % 8))
{
#if !(GEMM_ALIGNMENT) || (GEMM_ALIGNMENT == 8)
return launch<__NV_STD_MAX(8, sizeof(value_t)), accumulator_alignment>(
m,
n,
k,
epilogue,
A,
B,
C,
stream,
debug_synchronous);
#endif
}
else if ((!force_operand_alignment || force_operand_alignment == 4) &&
!((sizeof(operand_load_t) * lda) % 4) &&
!((sizeof(operand_load_t) * ldb) % 4))
{
#if !(GEMM_ALIGNMENT) || (GEMM_ALIGNMENT == 4)
return launch<__NV_STD_MAX(4, sizeof(value_t)), accumulator_alignment>(
m,
n,
k,
epilogue,
A,
B,
C,
stream,
debug_synchronous);
#endif
}
else if ((!force_operand_alignment || force_operand_alignment == 2) &&
!((sizeof(operand_load_t) * lda) % 2) &&
!((sizeof(operand_load_t) * ldb) % 2))
{
// 16-bit alignment only supported for HGEMM
#if defined(TEST_HGEMM) || defined(TEST_WGEMM)
#if !(GEMM_ALIGNMENT) || (GEMM_ALIGNMENT == 2)
return launch<__NV_STD_MAX(2, sizeof(value_t)), accumulator_alignment>(
m,
n,
k,
epilogue,
A,
B,
C,
stream,
debug_synchronous);
#endif
#endif
}
return gemm::launch_configuration(cudaErrorInvalidValue);
}
};
} // namespace cutlass

564
cutlass_test/gemm.cu Normal file
View File

@@ -0,0 +1,564 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
/**
* \file gemm.cu
* GEMM test driver
*
*/
#include <iostream>
#include <typeinfo>
#include <random>
#include <stdint.h>
// CUBLAS GEMM API
#include <cublas_v2.h>
// Set Cutlass debug macro to enable console printing of library errors
#define DEBUG
#if defined(WMMA)
// Conditionally include WMMA headers (CUDA 9 Preview Feature)
#include <mma.h>
#endif
// Cutlass GEMM API
#include <cutlass/util/util.h>
#include <cutlass/gemm/dispatch.h>
#include <cutlass/gemm/epilogue_function.h>
// Test utilities
#include "util/command_line.h"
#include "util/half.h"
#include "util/matrix.h"
#include "util/timer.h"
#include "util/type_conversion.h"
// Dispatch routines to CUBLAS and CUTLASS
#include "cublas_dispatch.h"
#include "cutlass_dispatch.h"
/******************************************************************************
* Globals, constants and typedefs
******************************************************************************/
using namespace cutlass;
/// CUBLAS handle
cublasHandle_t g_cublas_handle;
/// The device-id of the current device
int g_device_id = -1;
/// The number of timing iterations to invoke
int g_timing_iterations = -1;
/// The number of randomly-sized problems to schmoo
int g_schmoo = 0;
/******************************************************************************
* Number generation
******************************************************************************/
/**
* Simple low-integer generator
*/
struct simple_gen
{
std::default_random_engine generator;
std::uniform_int_distribution<int> distribution;
/// Constructor
simple_gen(int max) : distribution(max * -1, max)
{}
/// Functor
int operator()()
{
return distribution(generator);
}
};
/******************************************************************************
* Test execution
******************************************************************************/
/**
* Compute C = (alpha * A * B) + (beta * C)
*/
template <
typename test_func_t, ///< Test function type
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB, ///< Transformation op for matrix B
typename value_t, ///< Multiplicand value type (matrices A and B)
typename accum_t> ///< Accumulator value type (matrix C and scalars)
bool test(
int m, ///< Height of C in rows
int n, ///< Width of C in columns
int k, ///< Width (height) of A (B)
accum_t alpha, ///< Multiplicand scalar
accum_t beta) ///< Addend scalar
{
cudaStream_t stream = 0;
//
// Initialize matrices
//
matrix<value_t> A(
(TransformA == matrix_transform_t::NonTranspose) ? m : k,
(TransformA == matrix_transform_t::NonTranspose) ? k : m);
matrix<value_t> B(
(TransformB == matrix_transform_t::NonTranspose) ? k : n,
(TransformB == matrix_transform_t::NonTranspose) ? n : k);
matrix<accum_t> C(m, n);
// initialized matrices with small values precisely representable as integers
simple_gen a_gen(3);
simple_gen b_gen(5);
A.fill_random(a_gen);
B.fill_random(b_gen);
C.fill_ramp(0,0);
// // Alternatively, initialize with procedural values to simplify debugging incorrect results
// A.fill_ramp(1,2);
// B.fill_ramp(1,1);
// Sync to device
A.sync_device();
B.sync_device();
C.sync_device();
CUDA_PERROR(cudaPeekAtLastError());
CUDA_PERROR(cudaDeviceSynchronize());
//
// Run test once with debug-synchronous enabled and check result
//
if (!g_schmoo) printf("\n");
test_func_t test_func;
C.fill_ramp(0, 0);
C.sync_device();
cudaError_t error = test_func(
g_cublas_handle,
m,
n,
k,
A.d_data(),
B.d_data(),
C.d_data(),
alpha,
beta,
stream,
!g_schmoo).result;
bool not_applicable = (error == cudaErrorInvalidValue);
bool is_failed = false;
if (not_applicable)
{
printf(", NA");
}
else
{
CUDA_PERROR(error);
// Compute reference check if wont take too long on CPU
if ((!g_schmoo) && (m * n <= 1024 * 1024))
{
matrix<accum_t> ref_C(m, n);
ref_C.fill_ramp(0, 0);
ref_C.gemm(TransformA, TransformB, alpha, A, B, beta);
C.sync_host();
is_failed = (C != ref_C);
if (!g_schmoo)
{
if (is_failed)
{
printf("FAIL, ");
std::ofstream file_a("a.csv");
A.write_matrix(file_a);
std::ofstream file_b("b.csv");
B.write_matrix(file_b);
std::ofstream file_d("gemm-REF.csv");
ref_C.write_matrix(file_d);
std::ofstream file_c("gemm-GPU.csv");
C.write_matrix(file_c);
}
else
{
printf("PASS, ");
}
}
}
fflush(stdout);
//
// Warmup and timing iterations
//
if (g_timing_iterations > 0)
{
// Warmup for 1/100 of the timing iterations (minimum of 2)
for (int i = 0; i < __NV_STD_MAX(2, (g_timing_iterations + 99) / 100); ++i)
{
CUDA_PERROR(test_func(
g_cublas_handle,
m,
n,
k,
A.d_data(),
B.d_data(),
C.d_data(),
alpha,
beta,
stream,
false).result);
}
}
// Conduct timing iterations
double elapsed_ms = 0;
gpu_timer timer;
timer.start();
for (int i = 0; i < g_timing_iterations; i++)
{
CUDA_PERROR(test_func(
g_cublas_handle,
m,
n,
k,
A.d_data(),
B.d_data(),
C.d_data(),
alpha,
beta,
stream,
false).result);
}
timer.stop();
elapsed_ms += timer.elapsed_millis();
double avg_ms = elapsed_ms / g_timing_iterations;
// Display performance
if (g_timing_iterations > 0)
{
int64_t num_flops = (2 * int64_t(m) * int64_t(n) * int64_t(k)) + (2 * int64_t(m) * int64_t(n));
double gflops_per_sec = double(num_flops) / avg_ms / 1.0e6;
if (g_schmoo)
{
if (is_failed)
printf("F");
printf(", %.3f", gflops_per_sec);
// Sleep for a few milliseconds to cool
sleep_millis(10);
}
else
{
printf("Avg runtime: %.3f ms, total flops: %lld, GFLOP/s: %.2f\n",
avg_ms,
num_flops,
gflops_per_sec);
}
fflush(stdout);
}
}
return is_failed;
}
/**
* Compute C = (alpha * A * B) + (beta * C)
*/
template <
math_operation_class_t math_op,
matrix_transform_t::kind_t TransformA, ///< Transformation op for matrix A
matrix_transform_t::kind_t TransformB, ///< Transformation op for matrix B
typename value_t, ///< Multiplicand value type (matrices A and B)
typename accum_t> ///< Accumulator value type (matrix C and scalars)
bool test(
int m, ///< Height of C in rows
int n, ///< Width of C in columns
int k, ///< Width (height) of A (B)
accum_t alpha, ///< Multiplicand scalar
accum_t beta) ///< Addend scalar
{
uint64_t flop_base = 1ull << 41;
int max_timing_iterations = 10000;
int min_timing_iterations = 10;
bool test_error = false;
// Scale the number of timing iterations with respect to problem size (if not specified on commandline)
if ((g_timing_iterations < 0) || g_schmoo)
{
uint64_t num_flops = (2 * uint64_t(m) * uint64_t(n) * uint64_t(k)) + (2 * uint64_t(m) * uint64_t(n));
g_timing_iterations = (int) ((flop_base / sizeof(value_t)) / num_flops);
g_timing_iterations = (int) __NV_STD_MIN(max_timing_iterations, g_timing_iterations);
g_timing_iterations = (int) __NV_STD_MAX(min_timing_iterations, g_timing_iterations);
}
if (g_schmoo)
{
printf("%d, %d, %d, %c%c, %d, %d",
m, n, k,
(TransformA == matrix_transform_t::NonTranspose) ? 'n' : 't',
(TransformB == matrix_transform_t::NonTranspose) ? 'n' : 't',
m * n,
g_timing_iterations);
}
else
{
printf("\n------------------------------------------------------------\n");
printf("%dx%dx%d, GEMM_%c%c, %d C elements, %d timing iterations\n",
m, n, k,
(TransformA == matrix_transform_t::NonTranspose) ? 'n' : 't',
(TransformB == matrix_transform_t::NonTranspose) ? 'n' : 't',
m * n,
g_timing_iterations);
}
fflush(stdout);
// CUBLAS
test_error |= test<
cublas_gemm<gemm::tiling_strategy::Unknown, math_op, TransformA, TransformB, value_t, accum_t>,
TransformA,
TransformB,
value_t,
accum_t>(m, n, k, accum_t(alpha), accum_t(beta));
// CUTLASS
test_error |= test<
cutlass_gemm_dispatch<gemm::tiling_strategy::Small, math_op, TransformA, TransformB, value_t, accum_t>,
TransformA,
TransformB,
value_t,
accum_t>(m, n, k, accum_t(alpha), accum_t(beta));
test_error |= test<
cutlass_gemm_dispatch<gemm::tiling_strategy::Medium, math_op, TransformA, TransformB, value_t, accum_t>,
TransformA,
TransformB,
value_t,
accum_t>(m, n, k, accum_t(alpha), accum_t(beta));
test_error |= test<
cutlass_gemm_dispatch<gemm::tiling_strategy::Large, math_op, TransformA, TransformB, value_t, accum_t>,
TransformA,
TransformB,
value_t,
accum_t>(m, n, k, accum_t(alpha), accum_t(beta));
test_error |= test<
cutlass_gemm_dispatch<gemm::tiling_strategy::Tall, math_op, TransformA, TransformB, value_t, accum_t>,
TransformA,
TransformB,
value_t,
accum_t>(m, n, k, accum_t(alpha), accum_t(beta));
test_error |= test<
cutlass_gemm_dispatch<gemm::tiling_strategy::Wide, math_op, TransformA, TransformB, value_t, accum_t>,
TransformA,
TransformB,
value_t,
accum_t>(m, n, k, accum_t(alpha), accum_t(beta));
test_error |= test<
cutlass_gemm_dispatch<gemm::tiling_strategy::Huge, math_op, TransformA, TransformB, value_t, accum_t>,
TransformA,
TransformB,
value_t,
accum_t>(m, n, k, accum_t(alpha), accum_t(beta));
return test_error;
}
/******************************************************************************
* Main
******************************************************************************/
/**
* Main
*/
int main(int argc, const char **argv)
{
//
// Problem type (compiler-supplied so we don't compile everything)
//
// Define value_t and accum_t (multiplicand and accumulator types, respectively)
#if defined(TEST_SGEMM)
typedef float value_t;
typedef float accum_t;
const math_operation_class_t math_op = math_operation_class_t::scalar;
#elif defined(TEST_DGEMM)
typedef double value_t;
typedef double accum_t;
const math_operation_class_t math_op = math_operation_class_t::scalar;
#elif defined(TEST_HGEMM)
typedef __half value_t;
typedef __half accum_t;
const math_operation_class_t math_op = math_operation_class_t::scalar;
#elif defined(TEST_IGEMM)
typedef int8_t value_t;
typedef int32_t accum_t;
const math_operation_class_t math_op = math_operation_class_t::scalar;
#elif defined(TEST_WGEMM)
typedef half value_t;
typedef float accum_t;
const math_operation_class_t math_op = math_operation_class_t::matrix;
#else
#error Unknown GEMM type requested.
#endif
// Define transpose constants
#ifdef TRANSPOSE_A
static const matrix_transform_t::kind_t TransformA = matrix_transform_t::Transpose;
#else
static const matrix_transform_t::kind_t TransformA = matrix_transform_t::NonTranspose;
#endif
#ifdef TRANSPOSE_B
static const matrix_transform_t::kind_t TransformB = matrix_transform_t::Transpose;
#else
static const matrix_transform_t::kind_t TransformB = matrix_transform_t::NonTranspose;
#endif
//
// Commandline parsing
//
// Initialize command line
command_line args(argc, argv);
int m_factor = args.device_prop.multiProcessorCount * 128;
int m = round_nearest(4096, m_factor);
int k = 4096;
int n = 4096;
float alpha = 1.0;
float beta = 0.0;
g_device_id = args.device_id;
args.get_cmd_line_argument("m", m);
args.get_cmd_line_argument("n", n);
args.get_cmd_line_argument("k", k);
args.get_cmd_line_argument("i", g_timing_iterations);
args.get_cmd_line_argument("alpha", alpha);
args.get_cmd_line_argument("beta", beta);
args.get_cmd_line_argument("schmoo", g_schmoo);
// Print usage
if (args.check_cmd_line_flag("help"))
{
printf("%s "
"[--help] "
"[--i=<timing iterations>] "
"[--device=<device-id>] "
"[--alpha=<alpha> --beta=<beta>] "
"[--schmoo=<samples> || --m=<height> --n=<width> --k=<depth>]"
"\n", argv[0]);
exit(0);
}
// Initialize cuBLAS
if (cublasCreate(&g_cublas_handle) != CUBLAS_STATUS_SUCCESS)
{
fprintf(stderr, "cublasCreate() failed\n");
exit(1);
}
bool test_error = false;
if (g_schmoo)
{
// Run a schmoo of problem sizes
printf("M, N, K, transpose, total_flops, timing_iterations, sol_flop/s, cublas_sol, cutlass_small_sol, cutlass_med_sol, cutlass_large_sol, cutlass_tall_sol, cutlass_wide_sol, cutlass_huge_sol\n");
// Generate power-law distribution from [32, 16384)
std::mt19937 gen(0);
std::uniform_real_distribution<float> dis(5, 14);
for (int i = 0; i < g_schmoo; ++i)
{
int m = int(pow(float(2), dis(gen)));
int n = int(pow(float(2), dis(gen)));
int k = int(pow(float(2), dis(gen)));
// Round m and n to nearest multiple of 32 if < 128, otherwise to the nearest 128
m = (m < 128) ?
round_nearest(m, 32) :
round_nearest(m, 128);
n = (n < 128) ?
round_nearest(n, 32) :
round_nearest(n, 128);
// Round k to the nearest 16
k = (sizeof(value_t) == 1) ?
round_nearest(k, 32) :
round_nearest(k, 16);
test_error |= test<math_op, TransformA, TransformB, value_t, accum_t>(
m, n, k,
from_float<accum_t>(alpha),
from_float<accum_t>(beta));
printf("\n"); fflush(stdout);
}
}
else
{
// Test a single GEMM problem size
test_error |= test<math_op, TransformA, TransformB, value_t, accum_t>(
m,
n,
k,
from_float<accum_t>(alpha),
from_float<accum_t>(beta));
}
// Cleanup
cublasDestroy(g_cublas_handle);
return test_error;
}

View File

@@ -0,0 +1,312 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* Utility for parsing command line arguments
*/
#include <string>
#include <vector>
#include <sstream>
#include <iostream>
#include <limits>
#include <cuda_runtime.h>
#include <cutlass/util/debug.h>
namespace cutlass {
/******************************************************************************
* command_line
******************************************************************************/
/**
* Utility for parsing command line arguments
*/
struct command_line
{
std::vector<std::string> keys;
std::vector<std::string> values;
std::vector<std::string> args;
int device_id;
cudaDeviceProp device_prop;
float device_giga_bandwidth;
size_t device_free_physmem;
size_t device_total_physmem;
/**
* Constructor
*/
command_line(int argc, const char **argv, int device_id = -1) :
keys(10),
values(10),
device_id(device_id)
{
using namespace std;
for (int i = 1; i < argc; i++)
{
string arg = argv[i];
if ((arg[0] != '-') || (arg[1] != '-'))
{
args.push_back(arg);
continue;
}
string::size_type pos;
string key, val;
if ((pos = arg.find('=')) == string::npos) {
key = string(arg, 2, arg.length() - 2);
val = "";
} else {
key = string(arg, 2, pos - 2);
val = string(arg, pos + 1, arg.length() - 1);
}
keys.push_back(key);
values.push_back(val);
}
// Initialize device
CUDA_PERROR_EXIT(device_init());
}
/**
* Checks whether a flag "--<flag>" is present in the commandline
*/
bool check_cmd_line_flag(const char* arg_name)
{
using namespace std;
for (int i = 0; i < int(keys.size()); ++i)
{
if (keys[i] == string(arg_name))
return true;
}
return false;
}
/**
* Returns number of naked (non-flag and non-key-value) commandline parameters
*/
template <typename value_t>
int num_naked_args()
{
return args.size();
}
/**
* Returns the commandline parameter for a given index (not including flags)
*/
template <typename value_t>
void get_cmd_line_argument(int index, value_t &val)
{
using namespace std;
if (index < args.size()) {
istringstream str_stream(args[index]);
str_stream >> val;
}
}
/**
* Returns the value specified for a given commandline parameter --<flag>=<value>
*/
template <typename value_t>
void get_cmd_line_argument(const char *arg_name, value_t &val)
{
using namespace std;
for (int i = 0; i < int(keys.size()); ++i)
{
if (keys[i] == string(arg_name))
{
istringstream str_stream(values[i]);
str_stream >> val;
}
}
}
/**
* Returns the values specified for a given commandline parameter --<flag>=<value>,<value>*
*/
template <typename value_t>
void get_cmd_line_arguments(
const char *arg_name,
std::vector<value_t> &vals,
char sep = ',')
{
using namespace std;
if (check_cmd_line_flag(arg_name))
{
// Clear any default values
vals.clear();
// Recover from multi-value string
for (int i = 0; i < keys.size(); ++i)
{
if (keys[i] == string(arg_name))
{
string val_string(values[i]);
istringstream str_stream(val_string);
string::size_type old_pos = 0;
string::size_type new_pos = 0;
// Iterate <sep>-delimited values
value_t val;
while ((new_pos = val_string.find(sep, old_pos)) != string::npos)
{
if (new_pos != old_pos)
{
str_stream.width(new_pos - old_pos);
str_stream >> val;
vals.push_back(val);
}
// skip over delimiter
str_stream.ignore(1);
old_pos = new_pos + 1;
}
// Read last value
str_stream >> val;
vals.push_back(val);
}
}
}
}
/**
* The number of pairs parsed
*/
int parsed_argc()
{
return (int) keys.size();
}
/**
* Initialize device
*/
cudaError_t device_init()
{
cudaError_t error = cudaSuccess;
do
{
int deviceCount;
if (CUDA_PERROR(error = cudaGetDeviceCount(&deviceCount))) break;
if (deviceCount == 0) {
fprintf(stderr, "No devices supporting CUDA.\n");
exit(1);
}
if (device_id < 0)
{
get_cmd_line_argument("device", device_id);
}
if ((device_id > deviceCount - 1) || (device_id < 0))
{
device_id = 0;
}
if (CUDA_PERROR(error = cudaSetDevice(device_id))) break;
if (CUDA_PERROR(error = cudaMemGetInfo(&device_free_physmem, &device_total_physmem))) break;
if (CUDA_PERROR(error = cudaGetDeviceProperties(&device_prop, device_id))) break;
if (device_prop.major < 1) {
fprintf(stderr, "Device does not support CUDA.\n");
exit(1);
}
device_giga_bandwidth = float(device_prop.memoryBusWidth) * device_prop.memoryClockRate * 2 / 8 / 1000 / 1000;
} while (0);
return error;
}
//-------------------------------------------------------------------------
// Utility functions
//-------------------------------------------------------------------------
/// Tokenizes a comma-delimited list of string pairs delimited by ':'
static void tokenize(
std::vector<std::pair<std::string, std::string> > &tokens,
std::string const &str,
char delim = ',',
char sep = ':')
{
// Home-built to avoid Boost dependency
size_t s_idx = 0;
size_t d_idx = std::string::npos;
while (s_idx < str.size())
{
d_idx = str.find_first_of(delim, s_idx);
size_t end_idx = (d_idx != std::string::npos ? d_idx : str.size());
size_t sep_idx = str.find_first_of(sep, s_idx);
size_t offset = 1;
if (sep_idx == std::string::npos || sep_idx >= end_idx)
{
sep_idx = end_idx;
offset = 0;
}
std::pair<std::string, std::string> item(
str.substr(s_idx, sep_idx - s_idx),
str.substr(sep_idx + offset, end_idx - sep_idx - offset));
tokens.push_back(item);
s_idx = end_idx + 1;
}
}
/// Tokenizes a comma-delimited list of string pairs delimited by ':'
static void tokenize(
std::vector<std::string > &tokens,
std::string const &str,
char delim = ',',
char sep = ':')
{
std::vector<std::pair<std::string, std::string> > token_pairs;
tokenize(token_pairs, str, delim, sep);
for (auto const &tok : token_pairs)
{
tokens.push_back(tok.first);
}
}
};
} // namespace cutlass

View File

@@ -0,0 +1,83 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* \brief C++ exception semantics for CUDA error codes
*/
#include <iosfwd>
#include <cuda_runtime.h>
namespace cutlass {
/// C++ exception wrapper for CUDA \p cudaError_t
class cuda_exception : public std::exception
{
public:
/// Constructor
cuda_exception(
const char *msg = "",
cudaError_t err = cudaErrorUnknown)
:
msg(msg), err(err)
{}
/// Returns the explanatory string
const char *what() const noexcept
{
return msg;
}
/// Returns the underlying CUDA \p cudaError_t
cudaError_t cudaError() const
{
return err;
}
protected:
/// Explanatory string
const char *msg;
/// Underlying CUDA \p cudaError_t
cudaError_t err;
};
/// Writes a cudaError_t to an output stream
inline std::ostream & operator<<(std::ostream &out, cudaError_t result)
{
return out << cudaGetErrorString(result);
}
/// Writes a cuda_exception instance to an output stream
inline std::ostream & operator<<(std::ostream &out, cuda_exception const &e)
{
return out << e.what() << ": " << e.cudaError();
}
} // namespace cutlass

224
cutlass_test/util/half.h Normal file
View File

@@ -0,0 +1,224 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* Utilities for interacting with the opaque CUDA __half type
*/
#include <stdint.h>
#include <cuda_fp16.h>
#include <iosfwd>
namespace cutlass {
/******************************************************************************
* half_t
******************************************************************************/
/**
* Host-based fp16 data type compatible and convertible with __half
*/
struct half_t
{
uint16_t __x;
/// Constructor from __half
half_t(const __half &other)
{
__x = reinterpret_cast<const uint16_t&>(other);
}
/// Constructor from integer
half_t(int a)
{
*this = half_t(float(a));
}
/// Constructor from float
half_t(float a)
{
uint32_t ia = *reinterpret_cast<uint32_t*>(&a);
uint16_t ir;
ir = (ia >> 16) & 0x8000;
if ((ia & 0x7f800000) == 0x7f800000)
{
if ((ia & 0x7fffffff) == 0x7f800000)
{
ir |= 0x7c00; /* infinity */
}
else
{
ir = 0x7fff; /* canonical NaN */
}
}
else if ((ia & 0x7f800000) >= 0x33000000)
{
int32_t shift = (int32_t) ((ia >> 23) & 0xff) - 127;
if (shift > 15)
{
ir |= 0x7c00; /* infinity */
}
else
{
ia = (ia & 0x007fffff) | 0x00800000; /* extract mantissa */
if (shift < -14)
{ /* denormal */
ir |= ia >> (-1 - shift);
ia = ia << (32 - (-1 - shift));
}
else
{ /* normal */
ir |= ia >> (24 - 11);
ia = ia << (32 - (24 - 11));
ir = ir + ((14 + shift) << 10);
}
/* IEEE-754 round to nearest of even */
if ((ia > 0x80000000) || ((ia == 0x80000000) && (ir & 1)))
{
ir++;
}
}
}
this->__x = ir;
}
/// Cast to __half
operator __half() const
{
return reinterpret_cast<const __half&>(__x);
}
/// Cast to float
operator float() const
{
int sign = ((this->__x >> 15) & 1);
int exp = ((this->__x >> 10) & 0x1f);
int mantissa = (this->__x & 0x3ff);
uint32_t f = 0;
if (exp > 0 && exp < 31)
{
// normal
exp += 112;
f = (sign << 31) | (exp << 23) | (mantissa << 13);
}
else if (exp == 0)
{
if (mantissa)
{
// subnormal
exp += 113;
while ((mantissa & (1 << 10)) == 0)
{
mantissa <<= 1;
exp--;
}
mantissa &= 0x3ff;
f = (sign << 31) | (exp << 23) | (mantissa << 13);
}
else
{
// zero
f = 0;
}
}
else if (exp == 31)
{
if (mantissa)
{
f = 0x7fffffff; // not a number
}
else
{
f = (0xff << 23) | (sign << 31); // inf
}
}
return *reinterpret_cast<float const *>(&f);
}
/// Get raw storage
uint16_t raw()
{
return this->__x;
}
/// Assignment by sum
bool operator ==(const half_t &other)
{
return (this->__x == other.__x);
}
/// Increment
half_t& operator +=(const half_t &rhs)
{
*this = half_t(float(*this) + float(rhs));
return *this;
}
/// Decrement
half_t& operator -=(const half_t &rhs)
{
*this = half_t(float(*this) - float(rhs));
return *this;
}
/// Multiply
half_t operator*(const half_t &other)
{
return half_t(float(*this) * float(other));
}
/// Multiply
half_t operator+(const half_t &other)
{
return half_t(float(*this) + float(other));
}
};
/******************************************************************************
* I/O stream overloads
******************************************************************************/
/// Insert formatted \p half_t into the output stream
std::ostream& operator<<(std::ostream &out, const half_t &x)
{
out << (float)x;
return out;
}
/// Insert formatted \p __half into the output stream
std::ostream& operator<<(std::ostream &out, const __half &x)
{
return out << half_t(x);
}
} // namespace cutlass

495
cutlass_test/util/matrix.h Normal file
View File

@@ -0,0 +1,495 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* Matrix data structure providing basic CPU-based algorithms and
* operations that can be cloned and synchronized in GPU device memory
*/
#include <vector>
#include <fstream>
#include <cutlass/util/debug.h>
#include "../cutlass/util/matrix_transform.h"
#include "half.h"
namespace cutlass {
/**
* \brief Matrix data structure providing basic CPU-based algorithms and
* operations that be synchronized with a GPU-based replica
*/
template <typename value_t>
struct matrix
{
// Host value type (must be convertible to/from value_t)
typedef typename nv_std::conditional<
(nv_std::is_same<value_t, __half>::value), // If (value_t == __half) ...
half_t, // ... use half_t internally for host storage, else...
value_t>::type // ... use value_t directly
host_value_t;
//-----------------------------------------------------------------------------
// Data members
//-----------------------------------------------------------------------------
private:
/// M dimension (height in rows)
int _m;
/// N dimension (width in columns)
int _n;
/// Data array on host
std::vector<host_value_t> _h_data;
/// Clone of data array on GPU device
value_t *_d_data;
/// GPU Device identifier that clone synchronizes with
int _device_id;
public:
//-----------------------------------------------------------------------------
// Lifetime and synchronization
//-----------------------------------------------------------------------------
/**
* Constructor: zero-initializes the matrix.
*/
matrix(
int m, ///< Height of the matrix in rows
int n) ///< Width of the matrix in columns
:
_m(m),
_n(n),
_d_data(NULL),
_device_id(0)
{
_h_data.resize(_m * _n, 0);
CUDA_PERROR_EXIT(cudaMalloc((void ** )&_d_data, sizeof(value_t) * _m * _n));
CUDA_PERROR_EXIT(cudaGetDevice(&_device_id));
}
/// Destructor
~matrix()
{
if (_d_data)
{
CUDA_PERROR_EXIT(cudaFree(_d_data));
}
}
/**
* Synchronize the GPU-based replica with the current host-based matrix data
*/
void sync_device()
{
size_t bytes = _m * _n * sizeof(value_t);
CUDA_PERROR_EXIT(cudaMemcpy(_d_data, &_h_data[0], bytes, cudaMemcpyHostToDevice));
}
/**
* Synchronize the host-based replica with the current GPU-based matrix data
*/
void sync_host()
{
size_t bytes = _m * _n * sizeof(value_t);
CUDA_PERROR_EXIT(cudaMemcpy(&_h_data[0], _d_data, bytes, cudaMemcpyDeviceToHost));
}
//-----------------------------------------------------------------------------
// Inspectors
//-----------------------------------------------------------------------------
/**
* Return the height of the matrix, subject to the optional \p transpose_op
*/
int height(matrix_transform_t transpose_op = matrix_transform_t::NonTranspose) const
{
switch (transpose_op)
{
case matrix_transform_t::NonTranspose : return _m;
case matrix_transform_t::Transpose : return _n;
default: return -1;
}
}
/**
* Return the width of the matrix, subject to the optional \p transpose_op
*/
int width(matrix_transform_t transpose_op = matrix_transform_t::NonTranspose) const
{
switch (transpose_op)
{
case matrix_transform_t::NonTranspose : return _n;
case matrix_transform_t::Transpose : return _m;
default: return -1;
}
}
/**
* Return item at (x, y) coordinate of matrix, subject to the optional \p transform op
*/
host_value_t get(
int x,
int y,
matrix_transform_t transpose_op = matrix_transform_t::NonTranspose) const
{
switch (transpose_op)
{
case matrix_transform_t::NonTranspose : return _h_data[y + (x * _m)];
case matrix_transform_t::Transpose : return _h_data[x + (y * _m)];
default: return 0;
}
}
/**
* Return the distance (in items) within memory between elements of two
* consecutive columns which have the same row index, subject to the optional \p transform op
*/
int leading_dim(matrix_transform_t transpose_op = matrix_transform_t::NonTranspose) const
{
switch (transpose_op)
{
case matrix_transform_t::NonTranspose : return _m;
case matrix_transform_t::Transpose : return _n;
default: return 0;
}
}
/**
* Get host data pointer
*/
value_t* h_data()
{
return _h_data.data();
}
/**
* Get host data pointer
*/
value_t const* h_data() const
{
return _h_data.data();
}
/**
* Get device data pointer
*/
value_t const* d_data() const
{
return _d_data;
}
/**
* Get device data pointer
*/
value_t * d_data()
{
return _d_data;
}
//-----------------------------------------------------------------------------
// Initialization
//-----------------------------------------------------------------------------
/**
* Initialize matrix values with a 2D "ramp" defined as
* <tt>values(x, y) = (y * rs) + (x * cs)</tt>
*/
void fill_ramp(
host_value_t rs,
host_value_t cs)
{
for (int x = 0; x < _n; x++)
{
for (int y = 0; y < _m; y++)
{
_h_data[y + (x * _m)] = host_value_t((y * rs) + (x * cs));
}
}
}
/**
* Initialize matrix values such that all the elements of the principal diagonal
* are ones and all other elements are zeros
*/
void fill_identity()
{
for (int j = 0; j < _n; j++)
{
for (int i = 0; i < _m; i++)
{
_h_data[i + j * _m] = host_value_t(i == j ? 1 : 0);
}
}
}
/**
* Initialize matrix values using the random number \p generator. The
* \p generator reference is assumed to be a nullary functor that returns
* values convertible to the matrix \p value_t.
*/
template <typename T>
void fill_random(T & generator)
{
for (int j = 0; j < _n; j++)
{
for (int i = 0; i < _m; i++)
{
_h_data[i + j * _m] = (value_t) generator();
}
}
}
/**
* Element-wise matrix addition
*/
matrix & operator+=(matrix const &mat)
{
for (int j = 0; j < _n; j++)
{
for (int i = 0; i < _m; i++)
{
_h_data[i + j * _m] += mat._h_data[i + j * _m];
}
}
return *this;
}
/**
* Element-wise matrix subtraction
*/
matrix & operator-=(matrix const &mat)
{
for (int j = 0; j < _n; j++)
{
for (int i = 0; i < _m; i++)
{
_h_data[i + j * _m] -= mat._h_data[i + j * _m];
}
}
return *this;
}
//-----------------------------------------------------------------------------
// Output
//-----------------------------------------------------------------------------
/**
* Prints matrix in CSV to output stream
*/
template <typename _hv_t>
std::ostream & write_matrix(std::ostream &out, _hv_t)
{
for (int i = 0; i < _m; i++)
{
for (int j = 0; j < _n; j++)
{
out << (j ? "," : "") << _h_data[i + j * _m];
}
out << "\n";
}
return out;
}
/**
* Prints matrix in CSV to output stream
*/
std::ostream & write_matrix(std::ostream &out, int8_t)
{
for (int i = 0; i < _m; i++)
{
for (int j = 0; j < _n; j++)
{
out << (j ? "," : "") << int32_t(_h_data[i + j * _m]);
}
out << "\n";
}
return out;
}
/**
* Prints matrix in CSV to output stream
*/
std::ostream & write_matrix(std::ostream &out)
{
return write_matrix(out, _h_data[0]);
}
//-----------------------------------------------------------------------------
// Floating point "almost-equal" utilities
//-----------------------------------------------------------------------------
static bool almost_equal_ulps(half_t a, half_t b, int max_ulps)
{
if (a == b)
return true;
int32_t int_diff = abs(a.raw() - b.raw());
if (int_diff <= max_ulps)
return true;
return false;
}
static bool almost_equal_ulps(float a, float b, int max_ulps)
{
if (a == b)
return true;
int32_t int_diff = abs(*(int32_t*)&a - *(int32_t*)&b);
if (int_diff <= max_ulps)
return true;
return false;
}
static bool almost_equal_ulps(double a, double b, int max_ulps)
{
if (a == b)
return true;
int64_t int_diff = abs(*(int64_t*)&a - *(int64_t*)&b);
if (int_diff <= max_ulps)
return true;
return false;
}
static bool almost_equal_ulps(int32_t a, int32_t b, int max_ulps)
{
return (a == b);
}
//-----------------------------------------------------------------------------
// matrix operations
//-----------------------------------------------------------------------------
/**
* Returns matrix equality
*/
bool operator==(const matrix<value_t> &mat) const
{
int max_ulps = 30;
if (_m != mat._m || _n != mat._n)
{
fprintf(stderr, "Error: dimension mismatch during matrix comparison.\n"); exit(1);
}
for (int j = 0; j < _n; j++)
{
for (int i = 0; i < _m; i++)
{
if (!almost_equal_ulps(_h_data[i + j * _m], mat._h_data[i + j * _m], max_ulps))
{
return false;
}
}
}
return true;
}
/**
* Returns matrix inequality
*/
bool operator!=(const matrix<value_t> &mat) const
{
return !(*this == mat);
}
/**
* Computes this = (alpha * op(A) * op(B)) + (beta * this), specialized for gemm_nn
*/
template <typename multiplicand_t>
void gemm(
matrix_transform_t transform_a,
matrix_transform_t transform_b,
host_value_t alpha,
const matrix<multiplicand_t> &A,
const matrix<multiplicand_t> &B,
host_value_t beta)
{
// Sanity check dimensions
if ((_m != A.height(transform_a)) ||
(_n != B.width(transform_b)) ||
(A.width(transform_a) != B.height(transform_b)))
{
fprintf(stderr, "Error: dimension mismatch during gemm.\n");
exit(1);
}
int M = A.height(transform_a);
int K = A.width(transform_a);
int N = B.width(transform_b);
// Even the host-side implementation utilizes a blocking structure to improve
// verification performance
int DimBlockM = (M % 16 == 0) ? 16 : 1;
int DimBlockN = (N % 16 == 0) ? 16 : 1;
for (int i = 0; i < M; i += DimBlockM)
{
for (int j = 0; j < N; j += DimBlockN)
{
for (int block_y = 0; block_y < DimBlockM; block_y++)
{
for (int block_x = 0; block_x < DimBlockN; block_x++)
{
int y = i + block_y;
int x = j + block_x;
host_value_t accum(0);
for (int k = 0; k < K; k++)
{
accum += host_value_t(A.get(k, y, transform_a)) * host_value_t(B.get(x, k, transform_b));
}
_h_data[y + x * M] = (alpha * accum) + (beta * _h_data[y + x * M]);
}
}
}
}
}
};
} // namespace cutlass

99
cutlass_test/util/timer.h Normal file
View File

@@ -0,0 +1,99 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* GPU kernel timer
*/
#include <cuda_runtime.h>
#include <cutlass/util/debug.h>
namespace cutlass {
/******************************************************************************
* gpu_timer
******************************************************************************/
/**
* GPU event-based timer
*/
struct gpu_timer
{
cudaEvent_t _start;
cudaEvent_t _stop;
gpu_timer()
{
CUDA_PERROR_EXIT(cudaEventCreate(&_start));
CUDA_PERROR_EXIT(cudaEventCreate(&_stop));
}
~gpu_timer()
{
CUDA_PERROR_EXIT(cudaEventDestroy(_start));
CUDA_PERROR_EXIT(cudaEventDestroy(_stop));
}
void start()
{
CUDA_PERROR_EXIT(cudaEventRecord(_start, 0));
}
void stop()
{
CUDA_PERROR_EXIT(cudaEventRecord(_stop, 0));
}
float elapsed_millis()
{
float elapsed = 0.0;
CUDA_PERROR_EXIT(cudaEventSynchronize(_stop));
CUDA_PERROR_EXIT(cudaEventElapsedTime(&elapsed, _start, _stop));
return elapsed;
}
};
/******************************************************************************
* sleep_millis
******************************************************************************/
#ifdef _WIN32
#include <windows.h>
void sleep_millis(unsigned milliseconds)
{
Sleep(milliseconds);
}
#else
#include <unistd.h>
void sleep_millis(unsigned milliseconds)
{
usleep(milliseconds * 1000); // takes microseconds
}
#endif
} // namespace cutlass

View File

@@ -0,0 +1,155 @@
/******************************************************************************
* Copyright (c) 2011-2017, NVIDIA CORPORATION. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are not permitted.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
* (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
* LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
* ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
* SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
******************************************************************************/
#pragma once
/**
* \file
* \brief Utilities for converting between types and assessing traits
*/
#include "half.h"
namespace cutlass {
/******************************************************************************
* Float conversion utilities
******************************************************************************/
/// Convert float to value type
template <typename value_t>
value_t from_float(float val)
{
return value_t(val);
}
/// Convert float to value type (__half specialization)
template <>
__half from_float<__half>(float val)
{
return half_t(val);
}
/******************************************************************************
* Type conversion utilities
******************************************************************************/
/// Member \p type is defined as the signed integer type having the same size as \p T
template <typename T>
struct integer_alias;
template <>
struct integer_alias<int8_t> {
using type = int8_t;
};
template <>
struct integer_alias<half_t> {
using type = int16_t;
};
template <>
struct integer_alias<__half> {
using type = int16_t;
};
template <>
struct integer_alias<float> {
using type = int32_t;
};
template <>
struct integer_alias<int> {
using type = int32_t;
};
template <>
struct integer_alias<double> {
using type = int64_t;
};
/******************************************************************************
* Type-info utilities
******************************************************************************/
/// Returns a string to prefix 'gemm' to construct CUBLAS-like kernel names
template <math_operation_class_t math_op, typename value_t, typename accum_t> char const *to_prefix_string();
template <> char const *to_prefix_string<math_operation_class_t::scalar, half_t, half_t>() {
return "H";
}
template <> char const *to_prefix_string<math_operation_class_t::scalar, __half, __half>() {
return "H";
}
template <> char const *to_prefix_string<math_operation_class_t::scalar, float, float>() {
return "S";
}
template <> char const *to_prefix_string<math_operation_class_t::matrix, __half, __half>() {
return "WmmaH";
}
template <> char const *to_prefix_string<math_operation_class_t::matrix, __half, float>() {
return "WmmaS";
}
template <> char const *to_prefix_string<math_operation_class_t::scalar, double, double>() {
return "D";
}
template <> char const *to_prefix_string<math_operation_class_t::scalar, int8_t, int32_t>() {
return "I";
}
/******************************************************************************
* Maps value_t to the minimum vector size used to load operand
******************************************************************************/
template <typename T>
struct operand_load_type;
template <>
struct operand_load_type<int8_t> { using type = int32_t; };
template <typename T>
struct operand_load_type { using type = T; };
/******************************************************************************
* Minimum alignment requirement, if any, determined from value_t.
******************************************************************************/
template <typename value_t>
struct gemm_alignment_requirement;
template <>
struct gemm_alignment_requirement<uint8_t> { static const int value = 4; };
template <typename value_t>
struct gemm_alignment_requirement { static const int value = 0; };
} // namespace cutlass