From 374882be53f2a7558aeb6c4955b8b9da75b29ecf Mon Sep 17 00:00:00 2001 From: akerr Date: Mon, 11 Jun 2018 11:47:15 -0700 Subject: [PATCH] Replaced GoogleTest copy with submodule. Added updates to support intra-threadblock reductions. Added tests for same. --- cutlass/cutlass.h | 2 +- cutlass/fragment.h | 4 +- cutlass/fragment_multiply_add.h | 68 ++-- cutlass/gemm/clear_accumulators.h | 2 + cutlass/gemm/gemm.h | 139 ++++---- cutlass/gemm/gemm_epilogue.h | 20 +- cutlass/gemm/gemm_global_stream.h | 11 +- cutlass/gemm/gemm_global_tile.h | 161 ++++++--- cutlass/gemm/gemm_shared_tile.h | 113 ++++--- cutlass/gemm/gemm_traits.h | 125 +++++-- cutlass/gemm/hgemm_traits.h | 20 +- cutlass/gemm/igemm_global_tile.h | 92 ++++- cutlass/gemm/igemm_traits.h | 182 +++++++++- cutlass/gemm/linear_scaling.h | 11 +- cutlass/gemm/wmma_gemm_global_tile.h | 22 +- cutlass/iterator_access.h | 31 +- cutlass/load_store.h | 23 ++ cutlass/shape.h | 8 +- cutlass/tile_iterator.h | 20 +- cutlass/util/platform.h | 2 +- tools/test/unit/CMakeLists.txt | 8 + tools/test/unit/cutlass_unit_test.cpp | 17 + tools/test/unit/gemm/dgemm.cu | 233 +++++++++++++ tools/test/unit/gemm/gemm.h | 28 ++ tools/test/unit/gemm/gemm_testbed.h | 49 +-- tools/test/unit/gemm/hgemm_128x128x16.cu | 347 +++++++++++++++++++ tools/test/unit/gemm/hgemm_128x128x8.cu | 6 +- tools/test/unit/gemm/igemm_128x128x32.cu | 1 - tools/test/unit/gemm/igemm_32x32x128.cu | 238 +++++++++++++ tools/test/unit/gemm/sgemm_128x128x16.cu | 410 +++++++++++++++++++++++ tools/test/unit/gemm/sgemm_128x128x8.cu | 2 +- tools/test/unit/gemm/sgemm_128x32x16.cu | 294 ++++++++++++++++ tools/test/unit/gemm/sgemm_128x64x16.cu | 285 ++++++++++++++++ tools/test/unit/gemm/sgemm_128x64x8.cu | 6 +- tools/test/unit/gemm/sgemm_64x128x16.cu | 43 +++ tools/test/unit/gemm/sgemm_64x128x8.cu | 2 +- tools/test/unit/gemm/sgemm_64x32x16.cu | 277 +++++++++++++++ tools/test/unit/gemm/sgemm_64x64x16.cu | 294 ++++++++++++++++ tools/test/unit/gemm/wmma_gemm.cu | 16 + tools/util/host_tensor.h | 3 + 40 files changed, 3279 insertions(+), 336 deletions(-) create mode 100644 tools/test/unit/gemm/hgemm_128x128x16.cu create mode 100644 tools/test/unit/gemm/igemm_32x32x128.cu create mode 100644 tools/test/unit/gemm/sgemm_128x128x16.cu create mode 100644 tools/test/unit/gemm/sgemm_128x32x16.cu create mode 100644 tools/test/unit/gemm/sgemm_128x64x16.cu create mode 100644 tools/test/unit/gemm/sgemm_64x128x16.cu create mode 100644 tools/test/unit/gemm/sgemm_64x32x16.cu create mode 100644 tools/test/unit/gemm/sgemm_64x64x16.cu diff --git a/cutlass/cutlass.h b/cutlass/cutlass.h index 1e428b166..19600ec8f 100644 --- a/cutlass/cutlass.h +++ b/cutlass/cutlass.h @@ -33,7 +33,7 @@ #define CUTLASS_MAJOR 1 #define CUTLASS_MINOR 0 -#define CUTLASS_PATCH 0 +#define CUTLASS_PATCH 1 #define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH) #ifdef __NVCC__ diff --git a/cutlass/fragment.h b/cutlass/fragment.h index 53fa380c2..886b11405 100644 --- a/cutlass/fragment.h +++ b/cutlass/fragment.h @@ -184,7 +184,7 @@ struct FragmentIterator { /// The shape of the the fragment. typedef typename ShapeMul >::Shape FragmentShape; /// The linear strides for iterations. - typedef typename ShapeStrides::Shape Strides; + typedef typename ShapeStrides::Shape Strides; /// Ctor. template @@ -242,7 +242,7 @@ struct FragmentConstIterator { /// The shape of the the fragment. typedef typename ShapeMul >::Shape FragmentShape; /// The linear strides for iterations. - typedef typename ShapeStrides::Shape IterationsStrides; + typedef typename ShapeStrides::Shape IterationsStrides; /// Ctor. template diff --git a/cutlass/fragment_multiply_add.h b/cutlass/fragment_multiply_add.h index 2d31e793b..36a4d6f6a 100644 --- a/cutlass/fragment_multiply_add.h +++ b/cutlass/fragment_multiply_add.h @@ -49,21 +49,29 @@ struct FragmentMultiplyAdd { CUTLASS_DEVICE FragmentMultiplyAdd() {} /// Multiply : d = a*b. - template - CUTLASS_DEVICE void multiply(Scalar_ a, Fragment_ const& b, Fragment_& d) { - for (int j = 0; j < Fragment_::kElements; ++j) { - d[j] = a * b[j]; + template + CUTLASS_DEVICE void multiply(Scalar_ a, FragmentB_ const& b, FragmentCd_& d) { + int const kReduction = FragmentB_::kElements / FragmentCd_::kElements; + for (int j = 0; j < FragmentCd_::kElements; ++j) { + d[j] = a * b[j * kReduction + 0]; + for (int k = 1; k < kReduction; ++k) { + d[j] += a * b[j * kReduction + k]; + } } } /// Multiply : d = a*b + c. - template + template CUTLASS_DEVICE void multiply_add(Scalar_ a, - Fragment_ const& b, - Fragment_ const& c, - Fragment_& d) { - for (int j = 0; j < Fragment_::kElements; ++j) { - d[j] = a * b[j] + c[j]; + FragmentB_ const& b, + FragmentCd_ const& c, + FragmentCd_& d) { + int const kReduction = FragmentB_::kElements / FragmentCd_::kElements; + for (int j = 0; j < FragmentCd_::kElements; ++j) { + d[j] = a * b[j * kReduction + 0] + c[j]; + for (int k = 1; k < kReduction; ++k) { + d[j] += a * b[j * kReduction + k]; + } } } }; @@ -74,7 +82,7 @@ struct FragmentMultiplyAdd { template <> struct FragmentMultiplyAdd { /// The shape of the instruction. - typedef Shape<1, 1, 1, 1> InstructionShape; + typedef Shape<1, 1, 2, 1> InstructionShape; /// The type for A. typedef half ScalarA; /// The type for B. @@ -86,38 +94,48 @@ struct FragmentMultiplyAdd { CUTLASS_DEVICE FragmentMultiplyAdd() {} /// Multiply : d = a*b. - template - CUTLASS_DEVICE void multiply(half a, Fragment_ const& b, Fragment_& d) { + template + CUTLASS_DEVICE void multiply(half a, FragmentB_ const& b, FragmentCd_& d) { #if defined(__CUDACC__) && __CUDA_ARCH__ >= 530 + + // Assemble a half2 from a. + __half2 const a_half2 = __half2half2(a); // The input. __half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]); // The output. __half2* d_half2 = reinterpret_cast<__half2*>(&d[0]); - // Assemble a half2 from a. - __half2 const a_half2 = __half2half2(a); - - for (int i = 0; i < Fragment_::kElements / 2; ++i) { - d_half2[i] = __hmul2(a_half2, b_half2[i]); + int const kReduction = FragmentB_::kElements / FragmentCd_::kElements; + for (int j = 0; j < FragmentCd_::kElements / 2; ++j) { + d_half2[j] = __hmul2(a_half2, b_half2[j * kReduction + 0]); + for (int k = 1; k < kReduction; ++k) { + d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + k], d_half2[j]); + } } #endif } /// Multiply : d = a*b + c. - template - CUTLASS_DEVICE void multiply_add(half a, Fragment_ const& b, Fragment_ const& c, Fragment_& d) { + template + CUTLASS_DEVICE void multiply_add(half a, + FragmentB_ const& b, + FragmentCd_ const& c, + FragmentCd_& d) { #if defined(__CUDACC__) && __CUDA_ARCH__ >= 530 + // Assemble a half2 from a. + __half2 const a_half2 = __half2half2(a); // The inputs. __half2 const* b_half2 = reinterpret_cast<__half2 const*>(&b[0]); __half2 const* c_half2 = reinterpret_cast<__half2 const*>(&c[0]); // The output. __half2* d_half2 = reinterpret_cast<__half2*>(&d[0]); - // Assemble a half2 from a. - __half2 const a_half2 = __half2half2(a); - - for (int i = 0; i < Fragment_::kElements / 2; ++i) { - d_half2[i] = __hfma2(a_half2, b_half2[i], c_half2[i]); + int const kReduction = (FragmentB_::kElements / FragmentCd_::kElements); + for (int j = 0; j < FragmentCd_::kElements / 2; ++j) { + d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + 0], c_half2[j]); + for (int k = 1; k < kReduction; ++k) { + d_half2[j] = __hfma2(a_half2, b_half2[j * kReduction + k], d_half2[j]); + } } #endif } diff --git a/cutlass/gemm/clear_accumulators.h b/cutlass/gemm/clear_accumulators.h index 12e1f5790..441370f4c 100644 --- a/cutlass/gemm/clear_accumulators.h +++ b/cutlass/gemm/clear_accumulators.h @@ -39,6 +39,8 @@ struct ClearAccumulators { /// The shared storage. struct SharedStorage {}; + /// Ctor. + CUTLASS_DEVICE ClearAccumulators() {} /// Ctor. CUTLASS_DEVICE ClearAccumulators(SharedStorage& shared_storage) {} diff --git a/cutlass/gemm/gemm.h b/cutlass/gemm/gemm.h index 0ca093ff5..c50a3f04b 100644 --- a/cutlass/gemm/gemm.h +++ b/cutlass/gemm/gemm.h @@ -40,7 +40,7 @@ namespace gemm { //////////////////////////////////////////////////////////////////////////////////////////////////// template -__global__ void gemm_kernel(typename Gemm_::Params params) { +__global__ /*__launch_bounds__(Gemm_::kThreads)*/ void gemm_kernel(typename Gemm_::Params params) { // Declare shared memory. __shared__ typename Gemm_::SharedStorage shared_storage; @@ -193,6 +193,71 @@ struct Gemm { CUTLASS_DEVICE Gemm(Params const& params_, SharedStorage& shared_storage_) : params(params_), shared_storage(shared_storage_) {} + /// Consume a single iteration of the loop. + template + CUTLASS_DEVICE void consume_tile(typename Traits::GlobalLoadStream& global_stream, + typename Traits::SharedLoadStream& shared_load_stream, + typename Traits::MultiplyAdd::Accumulators& accumulators, + Index outer_k) { + // If that's the last "load iteration" update the predicates. + if (!kIsLastIteration) { + global_stream.move_to_residue(outer_k); + } + + // Load data for the next iteration of the main loop. + if (!kIsLastIteration) { + global_stream.copy(); + } + + // The unrolling steps for the main loop. + int const kUnrollingSteps = + Traits::MultiplyAdd::AccumulatorsPerWarp::kD / Traits::MultiplyAdd::InstructionShape::kD; + + CUTLASS_PRAGMA_UNROLL + for (int step = 0; step < kUnrollingSteps - 1; ++step) { + // Trigger the copy from shared memory for the next A/B values. + shared_load_stream.copy(step + 1); + // Make sure the values are available for the current iteration to do the multiply-add. + shared_load_stream.commit(step); + + // Do the math on the fragments of the current iteration. + typename Traits::MultiplyAdd multiply_add; + multiply_add.multiply_add(shared_load_stream.fragment_a(step), + shared_load_stream.fragment_b(step), + accumulators, + accumulators); + } + + // Make sure the data from shared memory has been entirely consumed. + Traits::shared_load_fence(true); + + // Commit the data in shared memory for A/B. + if (!kIsLastIteration) { + global_stream.commit(); + } + + // Make sure the data is in shared memory. + Traits::shared_store_fence(true); + + // Trigger the loads for the next iteration (if needed). + if (!kIsLastIteration) { + // Move to the next stage for the load (if it makes sense). + shared_load_stream.inc_stage(); + // Trigger the copy from shared memory for the next loop iteration. + shared_load_stream.copy(0); + } + + // Make sure the values are available for the current iteration to do the multiply-add. + shared_load_stream.commit(kUnrollingSteps - 1); + + // Do the math on the fragments of the current iteration. + typename Traits::MultiplyAdd multiply_add; + multiply_add.multiply_add(shared_load_stream.fragment_a(kUnrollingSteps - 1), + shared_load_stream.fragment_b(kUnrollingSteps - 1), + accumulators, + accumulators); + } + /// Do the GEMM. CUTLASS_DEVICE void multiply_add() { // Swizzle the IDs of the block (to enable better cache behavior). @@ -212,16 +277,11 @@ struct Gemm { // Create the accumulator clear. ClearAccumulators clear(shared_storage.main_loop.clear); - /// Define the mainloop iteration size - typedef typename Traits::MultiplyAdd MultiplyAdd; - // By how much we unroll the main loop. - Index const kUnroll = static_cast(MultiplyAdd::AccumulatorsPerWarp::kD); + Index const kUnroll = static_cast(Traits::OutputTile::kD); // If we do not have enough steps in the main loop, trigger the residue code. - if (params.k < kUnroll) { - global_stream.residue(params.k, true); - } + global_stream.move_to_residue(params.k); // Fetch the fragments for A and B from global memory. global_stream.copy(); @@ -232,9 +292,12 @@ struct Gemm { // Make sure the data is in shared memory. Traits::shared_store_fence(false); + // Rollback to the beginning of the GEMM-K dimension. It may have no impact. + global_stream.rollback(); + // The unrolling steps for the main loop. int const kUnrollingSteps = - MultiplyAdd::AccumulatorsPerWarp::kD / MultiplyAdd::InstructionShape::kD; + Traits::MultiplyAdd::AccumulatorsPerWarp::kD / Traits::MultiplyAdd::InstructionShape::kD; // Make sure we have at least 2 unrolling steps or our pipeling is not going to work. static_assert(kUnrollingSteps >= 2, "The pipelining assumes at least two steps"); @@ -246,59 +309,21 @@ struct Gemm { shared_load_stream.copy(0); // Allocate the accumulators. - typename MultiplyAdd::Accumulators accumulators; + typename Traits::MultiplyAdd::Accumulators accumulators; // Clear the accumulators. clear.clear(accumulators); + // The loop index. + Index outer_k = params.k - kUnroll; + // Enter the main loop and iterate. - typedef typename Traits::Index Index; - for (Index outer_k = params.k - kUnroll; outer_k > -kUnroll; outer_k -= kUnroll) { - // If that's the last "load iteration" update the predicates. - int const is_residue = outer_k <= kUnroll; - if (is_residue) { - global_stream.residue(outer_k); - } + for (; outer_k > 0; outer_k -= kUnroll) { + consume_tile(global_stream, shared_load_stream, accumulators, outer_k); + } - // Load data for the next iteration of the main loop. - global_stream.copy(); - - CUTLASS_PRAGMA_UNROLL - for (int step = 0; step < kUnrollingSteps - 1; ++step) { - // Trigger the copy from shared memory for the next A/B values. - shared_load_stream.copy(step + 1); - // Make sure the values are available for the current iteration to do the multiply-add. - shared_load_stream.commit(step); - - // Do the math on the fragments of the current iteration. - MultiplyAdd multiply_add; - multiply_add.multiply_add(shared_load_stream.fragment_a(step), - shared_load_stream.fragment_b(step), - accumulators, - accumulators); - } - - // Make sure the data from shared memory has been entirely consumed. - Traits::shared_load_fence(true); - - // Commit the data in shared memory for A/B. - global_stream.commit(); - - // Make sure the data is in shared memory. - Traits::shared_store_fence(true); - - // Move to the next stage for the load (if it makes sense). - shared_load_stream.inc_stage(); - // Trigger the copy from shared memory for the next loop iteration. - shared_load_stream.copy(0); - // Make sure the values are available for the current iteration to do the multiply-add. - shared_load_stream.commit(kUnrollingSteps - 1); - - // Do the math on the fragments of the current iteration. - MultiplyAdd multiply_add; - multiply_add.multiply_add(shared_load_stream.fragment_a(kUnrollingSteps - 1), - shared_load_stream.fragment_b(kUnrollingSteps - 1), - accumulators, - accumulators); + // Residual loop. + for (; outer_k > -kUnroll; outer_k -= kUnroll) { + consume_tile(global_stream, shared_load_stream, accumulators, outer_k); } // Epilogue. diff --git a/cutlass/gemm/gemm_epilogue.h b/cutlass/gemm/gemm_epilogue.h index de6513a40..bc2530777 100644 --- a/cutlass/gemm/gemm_epilogue.h +++ b/cutlass/gemm/gemm_epilogue.h @@ -117,6 +117,7 @@ struct GemmEpilogue { CUTLASS_DEVICE void epilogue_with_or_without_beta(Coord<3> const& block, Accumulators& accumulators) { + // The problem size. Coord<3> const bounds = cutlass::make_Coord(0, n, m); // The functor. @@ -153,6 +154,18 @@ struct GemmEpilogue { GlobalStoreIteratorD global_store_iterator( params.iterator_d, bounds, block, pointer_offset, predicate_offset); + // The transformer to transform before storing to shared memory. + SharedStoreTransformerD shared_store_transformer; + typename SharedStoreTransformerD::OutputFragment shared_store_transformed_d; + + // The iterator to store to shared memory. + SharedStoreIteratorD shared_store_iterator(params.shared_store_iterator_d, + shared_storage.shared_stream.store); + + // The iterator to load from shared memory. TODO: Use a stream. + SharedLoadIteratorD shared_load_iterator(params.shared_load_iterator_d, + shared_storage.shared_stream.load); + CUTLASS_PRAGMA_UNROLL for (int w = 0; w < Iterations::kW; ++w) { // Load the C matrix into fragment. @@ -166,20 +179,13 @@ struct GemmEpilogue { // Copy the accumulators to shared memory. int const offset = (h * Iterations::kW + w) * SharedStoreIteratorD::Fragment::kElements; - SharedStoreTransformerD shared_store_transformer; - typename SharedStoreTransformerD::OutputFragment shared_store_transformed_d; shared_store_transformer.transform(accumulators, offset, shared_store_transformed_d); - - SharedStoreIteratorD shared_store_iterator(params.shared_store_iterator_d, - shared_storage.shared_stream.store); shared_iterator_store(shared_store_iterator, shared_store_transformed_d); // Make sure the data is in shared memory. shared_store_fence(); // Copy the accumulators back to registers from shared memory. - SharedLoadIteratorD shared_load_iterator(params.shared_load_iterator_d, - shared_storage.shared_stream.load); typename SharedLoadIteratorD::Fragment fetched_d; shared_iterator_load(shared_load_iterator, fetched_d); diff --git a/cutlass/gemm/gemm_global_stream.h b/cutlass/gemm/gemm_global_stream.h index 194f0decf..ec675a38f 100644 --- a/cutlass/gemm/gemm_global_stream.h +++ b/cutlass/gemm/gemm_global_stream.h @@ -84,8 +84,9 @@ struct GlobalLoadStreamBase { typename StoreIterator::Params store_iterator; /// Setup the params. - CUTLASS_HOST_DEVICE int initialize(Pointer pointer, Index ld) { - int error_code = load_iterator.initialize(pointer, ld); + template + CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc, Pointer pointer, Index ld) { + int error_code = load_iterator.initialize(desc, pointer, ld); if (error_code) { return error_code; } @@ -128,6 +129,9 @@ struct GlobalLoadStreamBase { store_iterator.inc_stage(); } + /// Move to the beginning of the residue code. That's a new code path in CUTLASS 1.0.1. + CUTLASS_DEVICE void move_to_residue(Index k) { load_iterator.move_to_residue(k); } + /// Execute the residue code. CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) { load_iterator.residue(k); @@ -136,6 +140,9 @@ struct GlobalLoadStreamBase { } } + /// Rollback to the beginning of the GEMM-k dimension. + CUTLASS_DEVICE void rollback() { load_iterator.rollback(); } + /// The iterator. LoadIterator load_iterator; /// The fragment to fetch from shared memory. diff --git a/cutlass/gemm/gemm_global_tile.h b/cutlass/gemm/gemm_global_tile.h index 28bcc6a98..1cc3b3377 100644 --- a/cutlass/gemm/gemm_global_tile.h +++ b/cutlass/gemm/gemm_global_tile.h @@ -195,7 +195,8 @@ struct GemmGlobalIteratorAb struct Params : public BaseParams { /// Initializes params to load a strip-mined tile, given pointer and stride_h. - CUTLASS_HOST_DEVICE int initialize(Scalar const* ptr, Index stride_h) { + template + CUTLASS_HOST_DEVICE int initialize(GemmDesc_ const& desc, Scalar const* ptr, Index stride_h) { Index inc_d = 0; Index inc_advance = 0; // Move by some columns for each iteration in the H dimension. @@ -220,16 +221,75 @@ struct GemmGlobalIteratorAb (Base::Iterations::kH - 1) * inc_h; } - Base::Params::initialize(ptr, 0, stride_h, 0, inc_d, inc_h, 0, inc_advance); + // The dimensions of the tile. + int const kH = TileTraits_::Tile::kH; + int const kW = TileTraits_::Tile::kW * TileTraits_::kAccessSize; + + // Move to the residue. + Index const kBlock = kAdvance == IteratorAdvance::kH ? kH : kW; + // The jump in the gemm-k dimension. + Index const stride = kAdvance == IteratorAdvance::kH ? stride_h : 1; + + // Compute the offset to the residue and how to "come" back. + Index const kResidue = desc.k % kBlock; + if (kResidue > 0) { + move_to_residue_offset = (desc.k - kResidue) * stride; + } else { + move_to_residue_offset = (desc.k - kBlock) * stride; + } + + Base::Params::initialize(ptr, 0, stride_h, 1, inc_d, inc_h, 0, inc_advance); return 0; } + + // The extra offset to control moving to the residue. + Index move_to_residue_offset; }; - /// Offset of an individual lane from the start of the tile - Coord<4> thread_offset; - /// The parameters - Params params; + /// Ctor. + CUTLASS_DEVICE GemmGlobalIteratorAb(Params const& _params, + const Coord<3>& bounds, + const Coord<3>& block, + ThreadOffset thread_offset_func = ThreadOffset()) + : params(_params) { + thread_offset = thread_offset_func(); + // The column. + Index block_h = thread_offset[1]; + // The contiguous dimension. + Index block_w = thread_offset[2]; + // Add the blocks indices. + if (kAdvance == IteratorAdvance::kH) { + block_h += block[1]; + block_w += block[2]; + + } else { + block_h += block[2]; + block_w += block[1]; + } + + // Setup the pointer. + params.pointer += (block_h * params.stride_h + block_w); + + // Initialize predicates + initialize_predicates(bounds, make_Coord(0, block_h, block_w)); + } + + /// The accessor. + CUTLASS_DEVICE void get(typename Base::AccessType& value, int d, int h, int w, int c) const { + int const imm = + ComputeOffsetFromStrides::get(0, 0, w, c); + Load::load(value, params.pointer, imm); + } + + /// Increment the pointer in the H dimension. + CUTLASS_DEVICE void inc_h() { params.pointer += params.inc_h; } + /// Increment the pointer in the D dimension. + CUTLASS_DEVICE void inc_d() { params.pointer += params.inc_d; } + /// Increment the pointer to move to the next iteration. + CUTLASS_DEVICE void inc_advance() { params.pointer += params.inc_advance; } + + /// Initialize the predicates. CUTLASS_DEVICE void initialize_predicates(const Coord<3>& bounds, const Coord<3>& block) { // Setup the masks to control loads. predicates.fill(0); @@ -263,46 +323,29 @@ struct GemmGlobalIteratorAb } } - /// Ctor. - CUTLASS_DEVICE GemmGlobalIteratorAb(Params const& _params, - const Coord<3>& bounds, - const Coord<3>& block, - ThreadOffset thread_offset_func = ThreadOffset()) - : params(_params) { - thread_offset = thread_offset_func(); - // The column. - Index block_h = thread_offset[1]; - // The contiguous dimension. - Index block_w = thread_offset[2]; + /// Move to residue portion. + CUTLASS_DEVICE void move_to_residue(Index k) { + // Store the pointer and the predicates. + stored_pointer = params.pointer; + stored_predicates = predicates; - // Add the blocks indices. - if (kAdvance == IteratorAdvance::kH) { - block_h += block[1]; - block_w += block[2]; + // Move the pointer to the residue. + params.pointer += params.move_to_residue_offset; - } else { - block_h += block[2]; - block_w += block[1]; + // The dimensions of the tile. + int const kH = TileTraits_::Tile::kH; + int const kW = TileTraits_::Tile::kW * TileTraits_::kAccessSize; + + // The unrolling factor. + int const kUnroll = kAdvance == IteratorAdvance::kH ? kH : kW; + + // Clear the predicates for the residue. TODO: We can do something smarter. + int const kResidue = (int)(k % (Index)kUnroll); + if (kResidue > 0) { + residue(kResidue); } - - // Setup the pointer. - params.pointer += (block_h * params.stride_h + block_w); - - // Initialize predicates - initialize_predicates(bounds, make_Coord(0, block_h, block_w)); } - /// Increment the pointer in the H dimension. - CUTLASS_DEVICE void inc_h() { params.pointer += params.inc_h; } - /// Increment the pointer in the D dimension. - CUTLASS_DEVICE void inc_d() { params.pointer += params.inc_d; } - /// Increment the pointer to move to the next iteration. - CUTLASS_DEVICE void inc_advance() { params.pointer += params.inc_advance; } - - /// Returns the current pointer - CUTLASS_HOST_DEVICE - Scalar const* data() const { return params.pointer; } - /// That's the residue! Update the predicates. CUTLASS_DEVICE void residue(Index k) { // The coordinates of the thread. @@ -332,14 +375,26 @@ struct GemmGlobalIteratorAb } } + /// Rollback to beginning of first tile and initialize predicates. + CUTLASS_DEVICE void rollback() { + params.pointer = stored_pointer; + predicates = stored_predicates; + } + /// Is the iterator valid? CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const { int const bit = ComputeOffsetFromShape::get(d, h, w, c); return predicates[bit]; } + /// Offset of an individual lane from the start of the tile + Coord<4> thread_offset; + /// The parameters + Params params; + /// The pointer. + typename Base::Scalar const* stored_pointer; /// The predicates. - PredicateVector predicates; + PredicateVector predicates, stored_predicates; }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -439,6 +494,13 @@ struct GemmGlobalIteratorCd : public TileIteratorBaseparams.predicate_offset -= (h + pred_offset); } + /// The accessor. + CUTLASS_DEVICE void get(typename Base::AccessType& value, int d, int h, int w, int c) const { + int const imm = + ComputeOffsetFromStrides::get(0, 0, w, c); + Load::load(value, params.pointer, imm); + } + /// Increment the pointer in the C dimension. CUTLASS_DEVICE void inc_c() {} /// Increment the pointer in the W dimension. @@ -456,18 +518,19 @@ struct GemmGlobalIteratorCd : public TileIteratorBaseparams.predicate_offset -= params.predicate_inc_advance; } + /// The accessor. + CUTLASS_DEVICE void set(typename Base::AccessType const& value, int d, int h, int w, int c) { + int const imm = + ComputeOffsetFromStrides::get(0, 0, w, c); + Store::store( + value, params.pointer, imm); + } + /// Test the validity of the iterator. CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const { return predicates.at(w) && params.predicate_offset > 0; } - /// Returns the raw pointer - CUTLASS_HOST_DEVICE - Pointer data() { return params.pointer; } - - CUTLASS_HOST_DEVICE - Pointer const data() const { return params.pointer; } - /// The predicates for the row. cutlass::PredicateVector predicates; }; diff --git a/cutlass/gemm/gemm_shared_tile.h b/cutlass/gemm/gemm_shared_tile.h index 9ec4c9a27..7c61e0229 100644 --- a/cutlass/gemm/gemm_shared_tile.h +++ b/cutlass/gemm/gemm_shared_tile.h @@ -104,8 +104,7 @@ struct GemmSharedStoreWithSkewTileAbTraits { typedef Shape<0, ShapeCount::kWc, Threads::kH * kAccessSize> ImmediateOffsetStrides; struct ThreadOffset { - CUTLASS_HOST_DEVICE - Coord<4> operator()() const { + CUTLASS_HOST_DEVICE Coord<4> operator()() const { int offset = ComputeThreadOffsetFromStrides::get(); return make_Coord(0, 0, offset, 0); } @@ -164,22 +163,25 @@ struct GemmSharedLoadTileATraits { typedef Shape<1, 1, TileWithoutSkew::kW / kWarps / kThreadsPerWarp /* / kScalarsPerLds*/> Iterations; /// The strides in each dimension between different loads/stores. - typedef Shape Delta; - /// The strides in each dimension between different loads/stores. - typedef Shape + typedef Shape ImmediateOffsetStrides; + typedef Shape Delta; /// Computes the thread offset in (H, W) based on thread ID struct ThreadOffset { - CUTLASS_HOST_DEVICE - Coord<4> operator()() const { + CUTLASS_HOST_DEVICE Coord<4> operator()() const { // Extract the warp. - int const warp = threadIdx.x / kWarpSize % Warps::kW; - // Compute the row offset for each thread - int const lane = (threadIdx.x & 0x0e) / 2; + int const warp = threadIdx.x / kWarpSize; + // Extract the slice. + int const slice = warp / (Warps::kH * Warps::kW); + // Compute the row offset for each warp. + int const warp_row = warp % Warps::kW; + // Compute the row offset for each thread. + int const lane_row = (threadIdx.x & 0x0e) / 2; // The offset. - int const offset = (warp * ThreadsPerWarp::kW + lane) * kAccessSize; - + int const offset = + slice * Tile::kW * Tile::kC + (warp_row * ThreadsPerWarp::kW + lane_row) * kAccessSize; + // Embed the offset in a 4D coordinate vector. return make_Coord(0, 0, offset, 0); } }; @@ -231,23 +233,27 @@ struct GemmSharedLoadTileBTraits { /// The number of iterations needed to load/store the tile. typedef Shape<1, 1, TileWithoutSkew::kW / kWarps / kThreadsPerWarp /* / kAccessSize*/> Iterations; /// The strides in each dimension between different loads/stores. - typedef Shape Delta; - /// The strides in each dimension between different loads/stores. - typedef Shape + typedef Shape ImmediateOffsetStrides; + typedef Shape Delta; /// Computes the thread offset in (H, W) based on thread ID struct ThreadOffset { - CUTLASS_HOST_DEVICE - Coord<4> operator()() const { - // The position of the warp. - int const warp = threadIdx.x / (Warps::kW * kWarpSize); - - // Compute the column offset for each thread - int const lane = (threadIdx.x & 0x10) / 8 + (threadIdx.x & 0x01); + CUTLASS_HOST_DEVICE Coord<4> operator()() const { + // Extract the warp. + int const warp = threadIdx.x / kWarpSize; + // Extract the slice. + int const slice = warp / (Warps::kH * Warps::kW); + // The warp in the slice. + int const warp_in_slice = warp % (Warps::kH * Warps::kW); + // Compute the row offset for each warp. + int const warp_col = warp_in_slice / Warps::kW; + // Compute the row offset for each thread. + int const lane_col = (threadIdx.x & 0x10) / 8 + (threadIdx.x & 0x01); // The offset. - int const offset = (warp * ThreadsPerWarp::kH + lane) * kAccessSize; - + int const offset = + slice * Tile::kW * Tile::kC + (warp_col * ThreadsPerWarp::kH + lane_col) * kAccessSize; + // Embed the offset in a 4D coordinate. return make_Coord(0, 0, offset, 0); } }; @@ -297,28 +303,26 @@ struct GemmSharedStoreTileDTraits { /// Computes the thread offset in (H, W) based on thread ID struct ThreadOffset { - CUTLASS_HOST_DEVICE - Coord<4> operator()() const { - // We issue STS.128 in the epilogue to store the accumulators to shared memory. When we use - // STS.128, we have to guarantee that threads in groups of 8 do not have bank conflicts (i.e - // they write to different banks). + CUTLASS_HOST_DEVICE Coord<4> operator()() const { + // The warp. + int const warp = threadIdx.x / kWarpSize; + + // The position of the warp in the 2D tile. + int const warp_row = warp % Warps::kW; + int const warp_col = warp / Warps::kW; + + // We assume that the elements are distributed in a warps as 4 columns of 8 elements. The + // columns are stored in threads col0=[0, 2, 4, 6, 8, 10, 12, 14], col1=[1, 3, 5, 7, .., 15], + // col2=[16, 18, 20, ..., 30] and col3=[17, 19, ..., 31]. + int hi_halfwarp_offset = ((threadIdx.x >> 4) & 0x1) * OutputTile::kW; + int lo_halfwarp_offset = ((threadIdx.x >> 1) & 0x7) + ThreadsPerWarp::kW * warp_row; // Odd threads go to the second half of shared memory. int const row = threadIdx.x & 0x01; - - int const warp_id = (threadIdx.x >> 5); - - int const warp_row = (warp_id % Warps::kW); - int const warp_col = (warp_id / Warps::kW); - - int hi_halfwarp_offset = OutputTile::kW * ((threadIdx.x >> 4) & 1); - int lo_halfwarp_offset = (((threadIdx.x >> 1) & 0x7) + warp_row * ThreadsPerWarp::kW); - - int col = kAccessSize * lo_halfwarp_offset + - warp_col * (ThreadsPerWarp::kH / 2) * OutputTile::kW + hi_halfwarp_offset; - - int offset = row * kScalarsPerRow + col; - return make_Coord(0, 0, offset, 0); + int col = warp_col * (ThreadsPerWarp::kH / 2) * OutputTile::kW + + lo_halfwarp_offset * kAccessSize + hi_halfwarp_offset; + // Embed the offset in a 4D coords. + return make_Coord(0, 0, row * kScalarsPerRow + col, 0); } }; }; @@ -357,32 +361,39 @@ struct GemmSharedLoadTileDTraits { /// The number of scalars per row. We build a tile with 2 rows (to avoid bank conflicts). static int const kScalarsPerRow = kThreads / 2 * kScalarsPerThread + kSkew; - /// The tile. + /// The tile. We have 2 rows of scalars. We use those two rows to make sure we do not have bank + /// conflicts in the epilogue. typedef Shape<1, 2, kScalarsPerRow / kAccessSize, kAccessSize> Tile; // Compute the number of iterations per warp in the Tile::kH dimension. static int const kIterationsInHPerWarp = kTileH_ / ShapeCount::kCount; - // As shown above, the shared memory tile is composed of 2 rows and each rows is made of + // As explained above, the shared memory tile is composed of 2 rows and each rows is made of // kScalarsPerRow. A warp is expected to read from the 1st row, then move to the 2nd row and go // back to the 1st row. To model that scheme we define the Iterations shape as Shape. // However, in some cases, we have only 1 iteration per warp. In that case, we must define the - // shape as Shape<1, 1, ...>. The following code does that. + // shape as Shape<1, 1, ...>. The following code does that except that we hijack the kH dimension + // to keep the number of elements to reduce for split-K. static int const kIterationsH = kIterationsInHPerWarp == 1 ? 1 : 2; // As soon as we know kIterationsH, it is trivial to compute kIterationsD: static int const kIterationsD = kIterationsInHPerWarp / kIterationsH; + // If we have split-K enabled, we have to jump over the elements from the "odd/even" column of + // threads to grab the other elements. + static int const kSplitK = OutputTile::kW * ThreadsPerWarp::kH / 2 * Warps::kH; + /// The number of iterations needed to store the tile. - typedef Shape Iterations; + typedef Shape + Iterations; /// The strides in each dimension between different loads/stores. - typedef Shape Delta; + typedef Shape + ImmediateOffsetStrides; /// The strides in each dimension between different loads/stores. - typedef Shape ImmediateOffsetStrides; + typedef Shape Delta; /// Computes the thread offset in (H, W) based on thread ID struct ThreadOffset { - CUTLASS_HOST_DEVICE - Coord<4> operator()() const { + CUTLASS_HOST_DEVICE Coord<4> operator()() const { // Each warp works on a different column. int const h = threadIdx.x / kWarpSize; // Compute the row. diff --git a/cutlass/gemm/gemm_traits.h b/cutlass/gemm/gemm_traits.h index 7a77d4b0d..cb57c4d5c 100644 --- a/cutlass/gemm/gemm_traits.h +++ b/cutlass/gemm/gemm_traits.h @@ -74,7 +74,9 @@ template < /// The number of scalars per LDS for D. int kScalarsPerLdsD_, /// The number of stages in shared memory to do single/double/triple-buffering. - int kStages_> + int kStages_, + /// Do we do the residue in the prologue? + bool kResidueInPrologue_ = false> struct GemmConfig { // @@ -129,6 +131,9 @@ struct GemmConfig { /// The number of stages in shared memory to implement double, triple, more-buffering. static int const kStages = kStages_; + + /// Do we do the residue in the prologue? + static bool const kResidueInPrologue = kResidueInPrologue_; }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -229,8 +234,12 @@ struct GemmTileTraitsHelperA { /// The number of scalars in 4B. static int const kScalarsIn4B = sizeof(MultiplyAddScalar) > 4 ? 1 : 4 / sizeof(MultiplyAddScalar); + /// The skew for A. + static int const kSkewA = 128 / sizeof(MultiplyAddScalar) / GemmConfig_::kScalarsPerStsA / + GlobalTileTraits::Threads::kW * kScalarsIn4B; + /// The traits class to build the iterator to store data to shared memory for A^T. - typedef GemmSharedStoreWithSkewTileAbTraits< + typedef GemmSharedStoreWithSkewTileAbTraits < // The pointer is float. MultiplyAddScalar, // The tile has size KxM in GEMM's terminology. @@ -242,9 +251,8 @@ struct GemmTileTraitsHelperA { // The number of scalars per STS. GemmConfig_::kScalarsPerStsA, // The skew to avoid bank conflicts added in the tile W dimension. - 128 / sizeof(MultiplyAddScalar) / GemmConfig_::kScalarsPerStsA / - GlobalTileTraits::Threads::kW * kScalarsIn4B> - SharedStoreTileTraits; + kSkewA + SharedStoreTileTraits; /// The traits class to build the iterator to load from shared memory for A^T. typedef GemmSharedLoadTileATraits< @@ -302,8 +310,12 @@ struct GemmTileTraitsHelperB { /// The number of scalars in 4B. static int const kScalarsIn4B = sizeof(MultiplyAddScalar) > 4 ? 1 : 4 / sizeof(MultiplyAddScalar); + /// The skew for B. + static int const kSkewB = 128 / sizeof(MultiplyAddScalar) / GemmConfig_::kScalarsPerStsB / + GlobalTileTraits::Threads::kW * kScalarsIn4B; + /// The traits class to build the iterator to store data to shared memory for B^N. - typedef GemmSharedStoreWithSkewTileAbTraits< + typedef GemmSharedStoreWithSkewTileAbTraits < // The pointer is float. MultiplyAddScalar, // The tile has size KxN in GEMM's terminology. @@ -315,9 +327,8 @@ struct GemmTileTraitsHelperB { // The number of scalars per STS. GemmConfig_::kScalarsPerStsB, // The skew to avoid bank conflicts added in the tile W dimension. - 128 / sizeof(MultiplyAddScalar) / GemmConfig_::kScalarsPerStsB / - GlobalTileTraits::Threads::kW * kScalarsIn4B> - SharedStoreTileTraits; + kSkewB + SharedStoreTileTraits; /// The traits class to build the iterator to load from shared memory for B^N. typedef GemmSharedLoadTileBTraits< @@ -405,6 +416,60 @@ struct GemmTileTraitsHelperB { //////////////////////////////////////////////////////////////////////////////////////////////////// +template +struct GemmResidue { + /// Move to residue portion. + template + static CUTLASS_DEVICE void move_to_residue(typename GemmTraits_::GlobalLoadStreamA& stream_a, + typename GemmTraits_::GlobalLoadStreamB& stream_b, + typename GemmTraits_::Index k) { + // The new code path in CUTLASS 1.0.1: We treat the residue in the prologue so we can have + // complete main loops after that. It helps simplify the logic in the main loop. + if (kIsPrologue) { + stream_a.move_to_residue(k); + stream_b.move_to_residue(k); + } + } + + /// Rollback to beginning of first tile and initialize predicates. + static CUTLASS_DEVICE void rollback(typename GemmTraits_::GlobalLoadStreamA& stream_a, + typename GemmTraits_::GlobalLoadStreamB& stream_b) { + stream_a.rollback(); + stream_b.rollback(); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct GemmResidue { + /// Move to residue portion. + template + static CUTLASS_DEVICE void move_to_residue(typename GemmTraits_::GlobalLoadStreamA& stream_a, + typename GemmTraits_::GlobalLoadStreamB& stream_b, + typename GemmTraits_::Index k) { + // The index. + typedef typename GemmTraits_::Index Index; + // By how much we unroll the main loop. + Index const kUnroll = static_cast(GemmTraits_::OutputTile::kD); + + // Call the residue code. That's the same path as CUTLASS 1.0.0. + if (kIsPrologue && k < kUnroll) { + stream_a.residue(k, true); + stream_b.residue(k, true); + } else if (k <= kUnroll) { + stream_a.residue(k, false); + stream_b.residue(k, false); + } + } + + /// Rollback to beginning of first tile and initialize predicates. + static CUTLASS_DEVICE void rollback(typename GemmTraits_::GlobalLoadStreamA& stream_a, + typename GemmTraits_::GlobalLoadStreamB& stream_b) {} +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + template < /// The GEMM configuration. typename GemmConfig_, @@ -426,10 +491,24 @@ template < typename ClearAccumulators_ = ClearAccumulators > struct GemmTraits { + /// This class. + typedef GemmTraits + This_; + /// The configuration. typedef GemmConfig_ GemmConfig; /// The output tile. typedef typename GemmConfig::OutputTile OutputTile; + /// Is the residue treated in the prologue? + static bool const kResidueInPrologue = GemmConfig::kResidueInPrologue; /// The stream to load A from global memory to shared memory. typedef GlobalLoadStreamA_ GlobalLoadStreamA; @@ -450,18 +529,6 @@ struct GemmTraits { /// The iterator for B to load from shared memory. typedef SharedLoadStreamB_ SharedLoadStreamB; - /// The shared storage for A. - typedef typename GlobalLoadStreamA::SharedStoreStorage SharedStoreStorageA; - // Btw, make sure we did not messed up with the size of the storage. - static_assert(sizeof(SharedStoreStorageA) == sizeof(typename SharedLoadStreamA::SharedStorage), - ""); - - /// The shared storage for B. - typedef typename GlobalLoadStreamB::SharedStoreStorage SharedStoreStorageB; - // Btw, make sure we did not messed up with the size of the storage. - static_assert(sizeof(SharedStoreStorageB) == sizeof(typename SharedLoadStreamB::SharedStorage), - ""); - /// The multiply-add functor. typedef typename GemmConfig::MultiplyAdd MultiplyAdd; /// The epilogue. @@ -502,14 +569,15 @@ struct GemmTraits { // Initialize the iterator for A. int error_code = - global_stream_a.initialize(reinterpret_cast(desc.d_a), desc.lda); + global_stream_a.initialize(desc, reinterpret_cast(desc.d_a), desc.lda); if (error_code) { return error_code; } // Initialize the iterator for B. - error_code = global_stream_b.initialize(reinterpret_cast(desc.d_b), desc.ldb); + error_code = + global_stream_b.initialize(desc, reinterpret_cast(desc.d_b), desc.ldb); if (error_code) { return error_code; @@ -574,12 +642,15 @@ struct GemmTraits { stream_b.commit(); } - /// Execute the residue code. - CUTLASS_DEVICE void residue(Index k, bool skip_clear = false) { - stream_a.residue(k, skip_clear); - stream_b.residue(k, skip_clear); + /// Move to residue portion. + template + CUTLASS_DEVICE void move_to_residue(Index k) { + GemmResidue::move_to_residue(stream_a, stream_b, k); } + /// Rollback to beginning of first tile and initialize predicates. + CUTLASS_DEVICE void rollback() { GemmResidue::rollback(stream_a, stream_b); } + /// The stream for A. GlobalLoadStreamA stream_a; /// The stream for B. diff --git a/cutlass/gemm/hgemm_traits.h b/cutlass/gemm/hgemm_traits.h index 78e5bac5b..b08645bf4 100644 --- a/cutlass/gemm/hgemm_traits.h +++ b/cutlass/gemm/hgemm_traits.h @@ -147,8 +147,11 @@ struct HgemmTileTraitsHelperA GemmConfig_::kScalarsPerLdgA> GlobalTileTraits; + /// The skew. + static int const kSkewA = 128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2; + /// The traits class to build the iterator to store data to shared memory for A^T. - typedef GemmSharedStoreWithSkewTileAbTraits< + typedef GemmSharedStoreWithSkewTileAbTraits < // The pointer. half, // The tile has size KxM in GEMM's terminology. @@ -160,8 +163,8 @@ struct HgemmTileTraitsHelperA // The number of scalars per STS (STS.32 or STS.128, etc). 2, // The skew to avoid bank conflicts added in the tile W dimension. - 128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2> - SharedStoreTileTraits; + kSkewA + SharedStoreTileTraits; /// The traits class to build the iterator to load from shared memory for A^T. typedef GemmSharedLoadTileATraits< @@ -212,8 +215,11 @@ struct HgemmTileTraitsHelperB GemmConfig_::kScalarsPerLdgB> GlobalTileTraits; + /// The skew for B. + static int const kSkewB = 128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2; + /// The traits class to build the iterator to store data to shared memory for B^N. - typedef GemmSharedStoreWithSkewTileAbTraits< + typedef GemmSharedStoreWithSkewTileAbTraits < // The pointer. half, // The tile has size KxN in GEMM's terminology. @@ -225,8 +231,8 @@ struct HgemmTileTraitsHelperB // The number of scalars per STS (STS.32 or STS.128, etc). 2, // The skew to avoid bank conflicts added in the tile W dimension. - 128 / sizeof(half) / GlobalTileTraits::Threads::kW / 2> - SharedStoreTileTraits; + kSkewB + SharedStoreTileTraits; /// The traits class to build the iterator to load from shared memory for B^N. typedef GemmSharedLoadTileBTraits< @@ -261,7 +267,7 @@ template < /// The functor to do the math in the epilogue. typename EpilogueFunctor_, /// The number of accumulators per thread. - typename AccumulatorsPerThread_ = Shape<32, 8, 8>, + typename AccumulatorsPerThread_ = Shape<8, 8, 16>, /// The number of halfs loaded in one LDG for A. int kScalarsPerLdgA_ = 2, /// The number of halfs loaded in one LDG for B. diff --git a/cutlass/gemm/igemm_global_tile.h b/cutlass/gemm/igemm_global_tile.h index 6993c631f..3f594ac6a 100644 --- a/cutlass/gemm/igemm_global_tile.h +++ b/cutlass/gemm/igemm_global_tile.h @@ -47,19 +47,19 @@ template -struct IgemmContiguousGlobalTileTraits : public GemmGlobalTileTraits< - // Which GEMM operand? - kOperand_, - // The layout. - kLayout_, - // The scalar. - Scalar_, - // The tile. - Tile_, - // The threads. - Threads_, - // The number of scalars per LDG/STG. - kAccessSize_> { +struct IgemmGlobalTileTraits : public GemmGlobalTileTraits< + // Which GEMM operand? + kOperand_, + // The layout. + kLayout_, + // The scalar. + Scalar_, + // The tile. + Tile_, + // The threads. + Threads_, + // The number of scalars per LDG/STG. + kAccessSize_> { /// The base class. typedef GemmGlobalTileTraits Base; /// The threads. @@ -91,5 +91,71 @@ struct IgemmContiguousGlobalTileTraits : public GemmGlobalTileTraits< //////////////////////////////////////////////////////////////////////////////////////////////////// +/// Deprecated. Please use IgemmGlobalTileTraits instead. + +template +struct IgemmContiguousGlobalTileTraits + : public IgemmGlobalTileTraits {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct IgemmGlobalIteratorAb : public GemmGlobalIteratorAb { + /// The base class. + typedef GemmGlobalIteratorAb Base; + /// The functor to compute the thread offset. + typedef typename TileTraits_::ThreadOffset ThreadOffset; + + /// Constructor. + CUTLASS_DEVICE IgemmGlobalIteratorAb(typename Base::Params const& _params, + const Coord<3>& bounds, + const Coord<3>& block, + ThreadOffset thread_offset_func = ThreadOffset()) + : Base(_params, bounds, block, thread_offset_func), in_residue_(false), mask_(0xffffffff) { + // The number of elements read in a single iteration. + int const kBlock = TileTraits_::Tile::kW * TileTraits_::kAccessSize; + // The residue. + int const kResidue = (int)(bounds[1] % kBlock); + + // Compute the number of elements that are valid. + int const left = kResidue - Base::thread_offset[2]; + if (left > 0 && left < 4) { + mask_ = (1u << (8 * left)) - 1u; + } + } + + /// The accessor. + CUTLASS_DEVICE void get(typename Base::AccessType& value, int d, int h, int w, int c) const { + Base::get(value, d, h, w, c); + if (in_residue_) { + reinterpret_cast(value) &= mask_; + } + } + + /// Move to residue portion. + CUTLASS_DEVICE void move_to_residue(typename Base::Index k) { + Base::move_to_residue(k); + in_residue_ = true; + } + + /// Move back to the beginning of the first tile. + CUTLASS_DEVICE void rollback() { + Base::rollback(); + in_residue_ = false; + } + + /// Are we in the residue? + bool in_residue_; + /// The mask to clean up the values. + uint32_t mask_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace gemm } // namespace cutlass diff --git a/cutlass/gemm/igemm_traits.h b/cutlass/gemm/igemm_traits.h index 9e8b93654..82f8de5cd 100644 --- a/cutlass/gemm/igemm_traits.h +++ b/cutlass/gemm/igemm_traits.h @@ -87,7 +87,9 @@ struct IgemmConfig /// The number of scalars per LDS for D. 1, /// The number of stages in shared memory. - 2> {}; + 2, + /// Enable the code path that deals with the residue in epilogue. + true> {}; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -125,17 +127,19 @@ struct IgemmConfig /// The number of scalars per LDS for D. 4, /// The number of stages in shared memory. - 2> {}; + 2, + /// Enable the code path that deals with the residue in epilogue. + true> {}; //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct IgemmTileTraitsHelperA : public GemmTileTraitsHelperA {}; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct IgemmTileTraitsHelperA +template +struct IgemmTileTraitsHelperA : public GemmTileTraitsHelperA { /// The base config. typedef GemmTileTraitsHelperA Base; @@ -144,7 +148,7 @@ struct IgemmTileTraitsHelperA static int const kScalarsPerStsA = 16; /// The traits class to build the iterator to load data from global memory for A^N. - typedef IgemmContiguousGlobalTileTraits< + typedef IgemmGlobalTileTraits< GemmOperand::kA, // The layout. MatrixLayout::kColumnMajor, @@ -155,9 +159,12 @@ struct IgemmTileTraitsHelperA // The threads are distributed as warps x 32 (the traits may reorganize). Shape<1, ShapeCount::kCount, GemmConfig_::kWarpSize>, // The number of scalars per LDG (LDG.32 or LDG.128, etc). - 4> + GemmConfig_::kScalarsPerLdgA> GlobalTileTraits; + // The iterator. + typedef GemmGlobalIteratorAb GlobalLoadIterator; + /// The traits class to build the iterator to store data to shared memory for A^N. typedef GemmSharedStoreTileAbTraits< // The pointer is float. @@ -173,13 +180,149 @@ struct IgemmTileTraitsHelperA //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template +struct IgemmTileTraitsHelperA { + /// The layout. + static MatrixLayout::Kind const kLayout = MatrixLayout::kRowMajor; + + /// The input scalar. + typedef int8_t Scalar; + /// The scalar stored in shared memory. + typedef int8_t MultiplyAddScalar; + + /// The number of scalars per LDG/STS/LDS for A. + static int const kScalarsPerStsA = 16; + + /// The traits class to build the iterator to load data from global memory for A^T. + typedef IgemmGlobalTileTraits< + GemmOperand::kA, + // The layout. + MatrixLayout::kRowMajor, + // The pointer is float const. + int8_t const, + // The tile has size NxK in GEMM's terminology. + Shape<1, GemmConfig_::OutputTile::kW, GemmConfig_::OutputTile::kD>, + // The threads are distributed as warps x 32 (the traits may reorganize). + Shape<1, ShapeCount::kCount, GemmConfig_::kWarpSize>, + // The number of scalars per LDG (LDG.32 or LDG.128, etc). + GemmConfig_::kScalarsPerLdgA> + GlobalTileTraits; + + // The iterator. + typedef IgemmGlobalIteratorAb GlobalLoadIterator; + + /// The traits class to build the iterator to store data to shared memory for A^N. + typedef GemmSharedStoreWithSkewTileAbTraits< + // The pointer is int8. + int8_t, + // The tile has size KxN in GEMM's terminology. + Shape, + // The threads are distributed as (threads / K) x K (the traits may reorganize). + typename GlobalTileTraits::Threads, + // The number of scalars per STS. + kScalarsPerStsA, + // The skew to avoid bank conflicts added in the tile W dimension. + 16> + SharedStoreTileTraits; + + /// The traits class to build the iterator to load from shared memory for A^N. + typedef GemmSharedLoadTileATraits< + // The pointer is float const. + int8_t const, + // The output tile size. + typename GemmConfig_::OutputTile, + // The number of warps. + typename GemmConfig_::Warps, + // The number of threads per warp. + typename GemmConfig_::MultiplyAdd::ThreadsPerWarp, + // The shape of the FMA instruction. + typename GemmConfig_::InstructionShape, + // The number of stages. + GemmConfig_::kStages, + // The number of scalars per LDS. + 16, + // The skew. + SharedStoreTileTraits::kSkew> + SharedLoadTileTraits; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template struct IgemmTileTraitsHelperB : public GemmTileTraitsHelperB {}; //////////////////////////////////////////////////////////////////////////////////////////////////// -template -struct IgemmTileTraitsHelperB +template +struct IgemmTileTraitsHelperB { + /// The layout. + static MatrixLayout::Kind const kLayout = MatrixLayout::kColumnMajor; + + /// The input scalar. + typedef int8_t Scalar; + /// The scalar stored in shared memory. + typedef int8_t MultiplyAddScalar; + + /// The number of scalars per LDG/STS/LDS for B. + static int const kScalarsPerStsB = 16; + + /// The traits class to build the iterator to load data from global memory for B^T. + typedef IgemmGlobalTileTraits< + GemmOperand::kB, + // The layout. + MatrixLayout::kColumnMajor, + // The pointer is float const. + int8_t const, + // The tile has size NxK in GEMM's terminology. + Shape<1, GemmConfig_::OutputTile::kH, GemmConfig_::OutputTile::kD>, + // The threads are distributed as warps x 32 (the traits may reorganize). + Shape<1, ShapeCount::kCount, GemmConfig_::kWarpSize>, + // The number of scalars per LDG (LDG.32 or LDG.128, etc). + GemmConfig_::kScalarsPerLdgB> + GlobalTileTraits; + + // The iterator. + typedef IgemmGlobalIteratorAb GlobalLoadIterator; + + /// The traits class to build the iterator to store data to shared memory for B^N. + typedef GemmSharedStoreWithSkewTileAbTraits< + // The pointer is int8. + int8_t, + // The tile has size KxN in GEMM's terminology. + Shape, + // The threads are distributed as (threads / K) x K (the traits may reorganize). + typename GlobalTileTraits::Threads, + // The number of scalars per STS. + kScalarsPerStsB, + // The skew to avoid bank conflicts added in the tile W dimension. + 16> + SharedStoreTileTraits; + + /// The traits class to build the iterator to load from shared memory for B^N. + typedef GemmSharedLoadTileBTraits< + // The pointer is float const. + int8_t const, + // The output tile size. + typename GemmConfig_::OutputTile, + // The number of warps. + typename GemmConfig_::Warps, + // The number of threads per warp. + typename GemmConfig_::MultiplyAdd::ThreadsPerWarp, + // The shape of the FMA instruction. + typename GemmConfig_::InstructionShape, + // The number of stages. + GemmConfig_::kStages, + // The number of scalars per LDS. + 16, + // The skew. + SharedStoreTileTraits::kSkew> + SharedLoadTileTraits; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct IgemmTileTraitsHelperB : public GemmTileTraitsHelperB { /// The base config. typedef GemmTileTraitsHelperB Base; @@ -188,7 +331,7 @@ struct IgemmTileTraitsHelperB static int const kScalarsPerStsB = 16; /// The traits class to build the iterator to load data from global memory for B^T. - typedef IgemmContiguousGlobalTileTraits< + typedef IgemmGlobalTileTraits< GemmOperand::kB, // The layout. MatrixLayout::kRowMajor, @@ -199,9 +342,12 @@ struct IgemmTileTraitsHelperB // The threads are distributed as warps x 32 (the traits may reorganize). Shape<1, ShapeCount::kCount, GemmConfig_::kWarpSize>, // The number of scalars per LDG (LDG.32 or LDG.128, etc). - 4> + GemmConfig_::kScalarsPerLdgB> GlobalTileTraits; + // The iterator. + typedef GemmGlobalIteratorAb GlobalLoadIterator; + /// The traits class to build the iterator to store data to shared memory for B^N. typedef GemmSharedStoreTileAbTraits< // The pointer is float. @@ -266,13 +412,13 @@ struct IgemmTraitsHelper { /// The IGEMM config. typedef IgemmConfig GemmConfig; /// The GEMM config for A. - typedef IgemmTileTraitsHelperA GemmTileTraitsHelperA; + typedef IgemmTileTraitsHelperA GemmTileTraitsHelperA; /// The GEMM config for B. - typedef IgemmTileTraitsHelperB GemmTileTraitsHelperB; + typedef IgemmTileTraitsHelperB GemmTileTraitsHelperB; /// The iterator to load A from global memory. - typedef GemmGlobalIteratorAb - GlobalLoadIteratorA; + typedef typename GemmTileTraitsHelperA::GlobalLoadIterator GlobalLoadIteratorA; + /// The default transformer for A. typedef typename IgemmTransformerA::Transformer GlobalTransformerA; @@ -287,8 +433,8 @@ struct IgemmTraitsHelper { GlobalLoadStreamA; /// The iterator to load B from global memory. - typedef GemmGlobalIteratorAb - GlobalLoadIteratorB; + typedef typename GemmTileTraitsHelperB::GlobalLoadIterator GlobalLoadIteratorB; + // The default transformer for B. typedef typename IgemmTransformerB::Transformer GlobalTransformerB; diff --git a/cutlass/gemm/linear_scaling.h b/cutlass/gemm/linear_scaling.h index 05afaea19..979c93f96 100644 --- a/cutlass/gemm/linear_scaling.h +++ b/cutlass/gemm/linear_scaling.h @@ -1,4 +1,3 @@ - /*************************************************************************************************** * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. * @@ -61,17 +60,17 @@ struct LinearScaling { CUTLASS_DEVICE LinearScaling(Params const& params) : alpha(params.alpha), beta(params.beta) {} /// Evaluate the functor. - template - CUTLASS_DEVICE void evaluate(Fragment_ const& accum, Fragment_& output) { + template + CUTLASS_DEVICE void evaluate(FragmentA_ const& accum, FragmentB_& output) { FragmentMultiplyAdd mad; mad.multiply(alpha, accum, output); } /// Evaluate the functor. - template - CUTLASS_DEVICE void evaluate(Fragment_ const& accum, Fragment_ const& old, Fragment_& output) { + template + CUTLASS_DEVICE void evaluate(FragmentA_ const& accum, FragmentB_ const& old, FragmentB_& output) { FragmentMultiplyAdd mad; - Fragment_ tmp; + FragmentB_ tmp; mad.multiply(beta, old, tmp); mad.multiply_add(alpha, accum, tmp, output); } diff --git a/cutlass/gemm/wmma_gemm_global_tile.h b/cutlass/gemm/wmma_gemm_global_tile.h index 32d9759a9..dbd57f6b5 100644 --- a/cutlass/gemm/wmma_gemm_global_tile.h +++ b/cutlass/gemm/wmma_gemm_global_tile.h @@ -164,6 +164,13 @@ struct WmmaGemmGlobalIteratorCd : public TileIteratorBaseparams.predicate_offset -= (h + pred_offset); } + /// The accessor. + CUTLASS_DEVICE void get(typename Base::AccessType& value, int d, int h, int w, int c) const { + int const imm = + ComputeOffsetFromStrides::get(0, 0, w, c); + Load::load(value, params.pointer, imm); + } + /// Increment the pointer in the C dimension. CUTLASS_DEVICE void inc_c() {} /// Increment the pointer in the W dimension. @@ -181,18 +188,19 @@ struct WmmaGemmGlobalIteratorCd : public TileIteratorBase::get(d, h, w, 0); + Store::store( + value, params.pointer, imm); + } + /// Test the predicate. CUTLASS_DEVICE bool valid(int d, int h, int w, int c) const { return predicates.at(w) && params.predicate_offset > 0; } - /// Returns the raw pointer - CUTLASS_HOST_DEVICE - Pointer data() { return params.pointer; } - - CUTLASS_HOST_DEVICE - Pointer const data() const { return params.pointer; } - /// The predicates for the row. cutlass::PredicateVector predicates; }; diff --git a/cutlass/iterator_access.h b/cutlass/iterator_access.h index db87e0d13..e94beb734 100644 --- a/cutlass/iterator_access.h +++ b/cutlass/iterator_access.h @@ -45,14 +45,12 @@ CUTLASS_HOST_DEVICE void iterator_load(InputIterator &iterator, Fragment &fragme for (int w = 0; w < InputIterator::Iterations::kW; ++w) { for (int c = 0; c < InputIterator::Iterations::kC; ++c) { if (iterator.valid(d, h, w, c)) { - int const offset = - ComputeOffsetFromStrides::get( - 0, 0, w, c); - Load:: - load(reinterpret_cast( - frag_iterator.at(d, h, w, c)), - iterator.data(), - offset); + iterator.get(reinterpret_cast( + frag_iterator.at(d, h, w, c)), + d, + h, + w, + c); } } if (w < InputIterator::Iterations::kW - 1) { @@ -196,17 +194,12 @@ CUTLASS_HOST_DEVICE void iterator_store(OutputIterator &iterator, Fragment &frag for (int h = 0; h < OutputIterator::Iterations::kH; ++h) { for (int w = 0; w < OutputIterator::Iterations::kW; ++w) { if (iterator.valid(d, h, w, 0)) { - int const offset = - ComputeOffsetFromStrides::get( - d, h, w, 0); - - Store:: - store(reinterpret_cast( - frag_iterator.at(d, h, w, 0)), - iterator.data(), - offset); + iterator.set(reinterpret_cast( + frag_iterator.at(d, h, w, 0)), + d, + h, + w, + 0); } if (w < OutputIterator::Iterations::kW - 1) { iterator.inc_w(); diff --git a/cutlass/load_store.h b/cutlass/load_store.h index d3d0ce81e..5cb5eb672 100644 --- a/cutlass/load_store.h +++ b/cutlass/load_store.h @@ -106,6 +106,29 @@ struct Load { //////////////////////////////////////////////////////////////////////////////////////////////////// +#if defined(__CUDACC_VERSION_MAJOR) && __CUDACC_VERSION_MAJOR < 10 +// WAR bug in NVCC where the upper and lower half of the register end up being the same +template +struct Load { + /// The output type. + typedef typename Vectorize::Type AccessType; + + /// The store function. + static CUTLASS_DEVICE void load(AccessType& dst, half const* pointer, int offset) { + int2 tmp = reinterpret_cast(&pointer[offset])[0]; + dst.registers[0] = tmp.x; + dst.registers[1] = tmp.y; + + tmp = reinterpret_cast(&pointer[offset + 4])[0]; + dst.registers[2] = tmp.x; + dst.registers[3] = tmp.y; + } +}; + +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + template struct Load { /// The output type. diff --git a/cutlass/shape.h b/cutlass/shape.h index f0f63d9c3..4f6b222ee 100644 --- a/cutlass/shape.h +++ b/cutlass/shape.h @@ -150,9 +150,13 @@ struct ShapeMin { //////////////////////////////////////////////////////////////////////////////////////////////////// -template +template struct ShapeStrides { - typedef Shape Shape; + typedef Shape + Shape; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cutlass/tile_iterator.h b/cutlass/tile_iterator.h index 6543cebf9..5d39c4f80 100644 --- a/cutlass/tile_iterator.h +++ b/cutlass/tile_iterator.h @@ -73,7 +73,11 @@ struct IteratorFragment { * @brief A template defining \ref tile_traits_concept * @concept{tile_traits_concept} */ -template +template struct TileTraits { /// Shape of the tile typedef Tile_ Tile; @@ -501,6 +505,13 @@ struct TileLoadIterator : public TileIteratorBase::get(d, h, w, c); + Load::load(value, params.pointer, imm); + } + /// Increment in the D dimension CUTLASS_HOST_DEVICE void inc_d() { params.pointer += params.inc_d; } @@ -829,6 +840,13 @@ struct TileStoreIterator : public TileIteratorBase::get(d, h, w, c); + Store::store(value, params.pointer, imm); + } + public: /// Stores a fragment and advances to the next tile. template diff --git a/cutlass/util/platform.h b/cutlass/util/platform.h index 32c41a67a..2a44c10e6 100644 --- a/cutlass/util/platform.h +++ b/cutlass/util/platform.h @@ -299,7 +299,7 @@ typedef integral_constant true_type; /// The type used as a compile-time boolean with false value. typedef integral_constant false_type; -#if (!defined(_MSC_VER) && (__cplusplus < 201402L)) || (defined(_MSC_VER) && (_MSC_VER < 1900)) +#if (!defined(_MSC_VER) && (__cplusplus <= 201402L)) || (defined(_MSC_VER) && (_MSC_VER < 1900)) /// std::bool_constant template diff --git a/tools/test/unit/CMakeLists.txt b/tools/test/unit/CMakeLists.txt index 783bc1f53..93d0290ec 100644 --- a/tools/test/unit/CMakeLists.txt +++ b/tools/test/unit/CMakeLists.txt @@ -47,6 +47,7 @@ set(CUTLASS_UNIT_TEST_SOURCES core/tile_iterator.cu gemm/dgemm.cu gemm/hgemm_128x128x8.cu + gemm/hgemm_128x128x16.cu gemm/hgemm_128x32x8.cu gemm/hgemm_128x64x8.cu gemm/igemm_128x128x32.cu @@ -54,12 +55,19 @@ set(CUTLASS_UNIT_TEST_SOURCES gemm/igemm_128x32x32.cu gemm/igemm_128x128x32_float.cu gemm/igemm_128x128x32_int8.cu + gemm/igemm_32x32x128.cu gemm/sgemm_128x128x8.cu + gemm/sgemm_128x128x16.cu gemm/sgemm_128x64x8.cu + gemm/sgemm_128x64x16.cu gemm/sgemm_128x32x8.cu + gemm/sgemm_128x32x16.cu gemm/sgemm_64x128x8.cu + gemm/sgemm_64x128x16.cu gemm/sgemm_64x64x8.cu + gemm/sgemm_64x64x16.cu gemm/sgemm_64x32x8.cu + gemm/sgemm_64x32x16.cu gemm/wmma_gemm.cu ) diff --git a/tools/test/unit/cutlass_unit_test.cpp b/tools/test/unit/cutlass_unit_test.cpp index 0b2ab0d34..ec78c8a65 100644 --- a/tools/test/unit/cutlass_unit_test.cpp +++ b/tools/test/unit/cutlass_unit_test.cpp @@ -26,9 +26,26 @@ \brief CUTLASS Unit Tests */ +#include #include +void set_gtest_flag() { + // Default flags can be overwritten by --gtest_filter from commandline + cudaDeviceProp deviceProperties; + cudaGetDeviceProperties(&deviceProperties, 0); + + int deviceMajorMinor = deviceProperties.major * 10 + deviceProperties.minor; + + if (deviceMajorMinor < 53) + ::testing::GTEST_FLAG(filter) = "-*Igemm*:*Hgemm*:*mma*"; + else if (deviceMajorMinor < 61) + ::testing::GTEST_FLAG(filter) = "-*Igemm*:*mma*"; + else if (deviceMajorMinor < 70) + ::testing::GTEST_FLAG(filter) = "-*mma*"; +} + int main(int argc, char* arg[]) { + set_gtest_flag(); ::testing::InitGoogleTest(&argc, arg); return RUN_ALL_TESTS(); } diff --git a/tools/test/unit/gemm/dgemm.cu b/tools/test/unit/gemm/dgemm.cu index cd002b5b2..be78450b9 100644 --- a/tools/test/unit/gemm/dgemm.cu +++ b/tools/test/unit/gemm/dgemm.cu @@ -104,6 +104,64 @@ TEST(Dgemm_128x128x8, dgemm_512x256x64_nt) { //////////////////////////////////////////////////////////////////////////////////////////////////// +//Sliced-K configuration + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Dgemm_64x32x16, dgemm_64x32x16_nt) { + + typedef cutlass::gemm::DgemmTraits > GemmTraits; + run_gemm(64, 32, 16); +} + +TEST(Dgemm_64x32x16, dgemm_256x128x64_nt) { + + typedef cutlass::gemm::DgemmTraits > GemmTraits; + run_gemm(256, 128, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Dgemm_64x64x16, dgemm_64x64x16_nt) { + + typedef cutlass::gemm::DgemmTraits > GemmTraits; + run_gemm(64, 64, 16); +} + +TEST(Dgemm_64x64x16, dgemm_256x128x64_nt) { + + typedef cutlass::gemm::DgemmTraits > GemmTraits; + run_gemm(256, 128, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Dgemm_128x32x16, dgemm_128x32x8_nt) { + + typedef cutlass::gemm::DgemmTraits > GemmTraits; + run_gemm(128, 32, 16); +} + +TEST(Dgemm_128x32x16, dgemm_256x64x64_nt) { + + typedef cutlass::gemm::DgemmTraits > GemmTraits; + run_gemm(256, 64, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + // // DGEMM Column-Column // @@ -182,6 +240,64 @@ TEST(Dgemm_128x128x8, dgemm_512x256x64_nn) { //////////////////////////////////////////////////////////////////////////////////////////////////// +// Sliced-K configuration + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Dgemm_64x32x16, dgemm_64x32x16_nn) { + + typedef cutlass::gemm::DgemmTraits > GemmTraits; + run_gemm(64, 32, 16); +} + +TEST(Dgemm_64x32x16, dgemm_256x128x64_nn) { + + typedef cutlass::gemm::DgemmTraits > GemmTraits; + run_gemm(256, 128, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Dgemm_64x64x16, dgemm_64x64x16_nn) { + + typedef cutlass::gemm::DgemmTraits > GemmTraits; + run_gemm(64, 64, 16); +} + +TEST(Dgemm_64x64x16, dgemm_256x128x64_nn) { + + typedef cutlass::gemm::DgemmTraits > GemmTraits; + run_gemm(256, 128, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Dgemm_128x32x16, dgemm_128x32x16_nn) { + + typedef cutlass::gemm::DgemmTraits > GemmTraits; + run_gemm(128, 32, 16); +} + +TEST(Dgemm_128x32x16, dgemm_256x64x64_nn) { + + typedef cutlass::gemm::DgemmTraits > GemmTraits; + run_gemm(256, 64, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + // // DGEMM Row-Column // @@ -260,6 +376,64 @@ TEST(Dgemm_128x128x8, dgemm_512x256x64_tn) { //////////////////////////////////////////////////////////////////////////////////////////////////// +// Sliced-K configuration + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Dgemm_64x32x16, dgemm_64x32x16_tn) { + + typedef cutlass::gemm::DgemmTraits > GemmTraits; + run_gemm(64, 32, 16); +} + +TEST(Dgemm_64x32x16, dgemm_256x128x64_tn) { + + typedef cutlass::gemm::DgemmTraits > GemmTraits; + run_gemm(256, 128, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Dgemm_64x64x16, dgemm_64x64x16_tn) { + + typedef cutlass::gemm::DgemmTraits > GemmTraits; + run_gemm(64, 64, 16); +} + +TEST(Dgemm_64x64x16, dgemm_256x128x64_tn) { + + typedef cutlass::gemm::DgemmTraits > GemmTraits; + run_gemm(256, 128, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Dgemm_128x32x16, dgemm_128x32x8_tn) { + + typedef cutlass::gemm::DgemmTraits > GemmTraits; + run_gemm(128, 32, 16); +} + +TEST(Dgemm_128x32x16, dgemm_256x64x64_tn) { + + typedef cutlass::gemm::DgemmTraits > GemmTraits; + run_gemm(256, 64, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + // // DGEMM Row-Row // @@ -338,3 +512,62 @@ TEST(Dgemm_128x128x8, dgemm_512x256x64_tt) { //////////////////////////////////////////////////////////////////////////////////////////////////// +// Sliced-K configuration + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Dgemm_64x32x16, dgemm_64x32x16_tt) { + + typedef cutlass::gemm::DgemmTraits > GemmTraits; + run_gemm(64, 32, 16); +} + +TEST(Dgemm_64x32x16, dgemm_256x128x64_tt) { + + typedef cutlass::gemm::DgemmTraits > GemmTraits; + run_gemm(256, 128, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Dgemm_64x64x16, dgemm_64x64x16_tt) { + + typedef cutlass::gemm::DgemmTraits > GemmTraits; + run_gemm(64, 64, 16); +} + +TEST(Dgemm_64x64x16, dgemm_256x128x64_tt) { + + typedef cutlass::gemm::DgemmTraits > GemmTraits; + run_gemm(256, 128, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Dgemm_128x32x16, dgemm_128x32x8_tt) { + + typedef cutlass::gemm::DgemmTraits > GemmTraits; + run_gemm(128, 32, 16); +} + +TEST(Dgemm_128x32x16, dgemm_256x64x64_tt) { + + typedef cutlass::gemm::DgemmTraits > GemmTraits; + run_gemm(256, 64, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + diff --git a/tools/test/unit/gemm/gemm.h b/tools/test/unit/gemm/gemm.h index 2c1ee5eed..78cdbd11b 100644 --- a/tools/test/unit/gemm/gemm.h +++ b/tools/test/unit/gemm/gemm.h @@ -24,12 +24,18 @@ **************************************************************************************************/ #include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// template static void run_gemm( int m, int n, int k, + int lda, + int ldb, + int ldc, typename test::GemmTestbedTraits::host_type alpha = typename test::GemmTestbedTraits::host_type(1), typename test::GemmTestbedTraits::host_type beta = @@ -51,6 +57,9 @@ static void run_gemm( testbed(m, n, k, + lda, + ldb, + ldc, cutlass::convert(GemmTraits_::kLayoutA), cutlass::convert(GemmTraits_::kLayoutB), alpha, @@ -88,3 +97,22 @@ static void run_gemm( ASSERT_TRUE(testbed.verify_with_host()); } } + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +static void run_gemm( + int m, + int n, + int k, + typename test::GemmTestbedTraits::host_type alpha = + typename test::GemmTestbedTraits::host_type(1), + typename test::GemmTestbedTraits::host_type beta = + typename test::GemmTestbedTraits::host_type(0)) { + int lda = GemmTraits_::kLayoutA == cutlass::MatrixLayout::kColumnMajor ? m : k; + int ldb = GemmTraits_::kLayoutB == cutlass::MatrixLayout::kColumnMajor ? k : n; + + run_gemm(m, n, k, lda, ldb, m, alpha, beta); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/test/unit/gemm/gemm_testbed.h b/tools/test/unit/gemm/gemm_testbed.h index 089e515cb..47e90f61c 100644 --- a/tools/test/unit/gemm/gemm_testbed.h +++ b/tools/test/unit/gemm/gemm_testbed.h @@ -44,6 +44,8 @@ namespace cutlass { +//////////////////////////////////////////////////////////////////////////////////////////////////// + template struct GemmTestbedTraits : public cutlass::TypeTraits {}; @@ -68,6 +72,8 @@ struct GemmTestbedTraits struct GemmTestbed { // @@ -219,11 +225,11 @@ struct GemmTestbed { typedef cutlass::Coord::Rank> Coord_t; + size_t matrix_stride = layout == CUBLAS_OP_N ? columns * ldm : rows * ldm; + // TODO: Remove that (int) cast. Coord_t stride = cutlass::make_Coord( - rows * columns, layout == CUBLAS_OP_N ? 1 : ldm, layout == CUBLAS_OP_N ? ldm : 1, 1); - + (int)matrix_stride, layout == CUBLAS_OP_N ? 1 : ldm, layout == CUBLAS_OP_N ? ldm : 1, 1); Coord_t size = cutlass::make_Coord(1, rows, columns, 1); - tensor.reset(stride, size); } @@ -231,11 +237,13 @@ struct GemmTestbed { // Methods // - /// Constructs a workspace for verifying GEMM, assumes - /// dense packing. + /// Constructs a workspace for verifying GEMM. GemmTestbed(int M_, int N_, int K_, + int lda, + int ldb, + int ldc, cublasOperation_t layout_a, cublasOperation_t layout_b, Scalar alpha_ = Scalar(1), @@ -248,33 +256,6 @@ struct GemmTestbed { throw cutlass::cuda_exception("Failed to create CUBLAS handle"); } - resize(A, M_, K_, layout_a); - resize(B, K_, N_, layout_b); - resize(C_initial, M_, N_, layout_c); - resize(ref_host, M_, N_, layout_c); - resize(ref_cublas, M_, N_, layout_c); - resize(computed, M_, N_, layout_c); - } - - /// Constructs a workspace for verifying GEMM with arbitrary strides - GemmTestbed(int M_, - int N_, - int K_, - int ldc, - cublasOperation_t layout_a, - int lda, - cublasOperation_t layout_b, - int ldb, - Scalar alpha_ = Scalar(1), - Scalar beta_ = Scalar(0), - cublasGemmAlgo_t algorithm_ = CUBLAS_GEMM_DEFAULT, - cublasOperation_t layout_c = CUBLAS_OP_N) - : alpha(alpha_), beta(beta_), algorithm(algorithm_) { - status = cublasCreate(&handle); - if (status != CUBLAS_STATUS_SUCCESS) { - throw cutlass::cuda_exception("Failed to create CUBLAS handle"); - } - resize(A, M_, K_, layout_a, lda); resize(B, K_, N_, layout_b, ldb); resize(C_initial, M_, N_, layout_c, ldc); @@ -515,6 +496,8 @@ struct GemmTestbed { } // namespace test +//////////////////////////////////////////////////////////////////////////////////////////////////// + namespace cutlass { inline cublasOperation_t convert(cutlass::MatrixLayout::Kind layout) { switch (layout) { @@ -527,4 +510,6 @@ inline cublasOperation_t convert(cutlass::MatrixLayout::Kind layout) { } return CUBLAS_OP_N; } + +//////////////////////////////////////////////////////////////////////////////////////////////////// } diff --git a/tools/test/unit/gemm/hgemm_128x128x16.cu b/tools/test/unit/gemm/hgemm_128x128x16.cu new file mode 100644 index 000000000..1d72971d2 --- /dev/null +++ b/tools/test/unit/gemm/hgemm_128x128x16.cu @@ -0,0 +1,347 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#include +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_2x2x2_nt) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(2, 2, 2); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_128x128x8_nt) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(128, 128, 8); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_128x128x16_nt) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(128, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_128x128x17_nt) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(128, 128, 17); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_128x128x64_nt) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(128, 128, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_256x128x16_nt) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(256, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_128x256x16_nt) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(128, 256, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_256x256x16_nt) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(256, 256, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_128x128x16_nn) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(128, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_128x128x18_nn) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(128, 128, 18); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_128x128x64_nn) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(128, 128, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_256x128x16_nn) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(256, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_128x256x16_nn) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(128, 256, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_256x256x16_nn) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(256, 256, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_128x128x16_tn) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(128, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_128x128x18_tn) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(128, 128, 18); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_128x128x64_tn) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(128, 128, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_256x128x16_tn) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(256, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_128x256x16_tn) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(128, 256, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_256x256x16_tn) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(256, 256, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_128x128x16_tt) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(128, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_128x128x18_tt) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(128, 128, 18); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_128x128x64_tt) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(128, 128, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_256x128x16_tt) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(256, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_128x256x16_tt) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(128, 256, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_256x256x16_tt) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(256, 256, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_128x128x16_alpha2_nt) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(128, 128, 16, cutlass::half_t(2), cutlass::half_t(0)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +TEST(Hgemm_128x128x16, hgemm_128x128x16_beta1_nt) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(128, 128, 16, cutlass::half_t(1), cutlass::half_t(1)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_128x128x16_alpha2_beta1_nt) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(128, 128, 16, cutlass::half_t(2), cutlass::half_t(1)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_120x112x64_ldg8_nt) { + // Load 8 halfs per LDG for A/B. + typedef cutlass::gemm::HgemmTraits, + cutlass::gemm::LinearScaling, + cutlass::Shape<8, 8, 16>, + 8, 8> + HgemmTraits; + run_gemm(120, 112, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_508x252x120_ragged_nt) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(508, 252, 120); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_124x126x32_ragged_nt) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(124, 126, 32); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Hgemm_128x128x16, hgemm_124x126x32_ragged_alpha2_beta1_nt) { + typedef cutlass::gemm::HgemmTraits > + HgemmTraits; + run_gemm(124, 126, 32, cutlass::half_t(2), cutlass::half_t(1)); +} +//////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/tools/test/unit/gemm/hgemm_128x128x8.cu b/tools/test/unit/gemm/hgemm_128x128x8.cu index 9495de8e3..266cce8a1 100644 --- a/tools/test/unit/gemm/hgemm_128x128x8.cu +++ b/tools/test/unit/gemm/hgemm_128x128x8.cu @@ -345,7 +345,7 @@ TEST(Hgemm_128x128x8, hgemm_128x128x16_alpha2_beta1_nt) { TEST(Hgemm_128x128x8, hgemm_120x112x64_ldg8_nt) { // Load 8 halfs per LDG for A/B. typedef cutlass::gemm::HgemmTraits, cutlass::gemm::LinearScaling, cutlass::Shape<8, 8, 16>, @@ -367,7 +367,7 @@ TEST(Hgemm_128x128x8, hgemm_508x252x120_ragged_nt) { TEST(Hgemm_128x128x8, hgemm_124x126x32_ragged_nt) { typedef cutlass::gemm::HgemmTraits > HgemmTraits; run_gemm(124, 126, 32); @@ -377,7 +377,7 @@ TEST(Hgemm_128x128x8, hgemm_124x126x32_ragged_nt) { TEST(Hgemm_128x128x8, hgemm_124x126x32_ragged_alpha2_beta1_nt) { typedef cutlass::gemm::HgemmTraits > HgemmTraits; run_gemm(124, 126, 32, cutlass::half_t(2), cutlass::half_t(1)); diff --git a/tools/test/unit/gemm/igemm_128x128x32.cu b/tools/test/unit/gemm/igemm_128x128x32.cu index f31a47fd9..aad3d4929 100644 --- a/tools/test/unit/gemm/igemm_128x128x32.cu +++ b/tools/test/unit/gemm/igemm_128x128x32.cu @@ -25,7 +25,6 @@ #include #include #include -#include #include //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/test/unit/gemm/igemm_32x32x128.cu b/tools/test/unit/gemm/igemm_32x32x128.cu new file mode 100644 index 000000000..8af1f4e33 --- /dev/null +++ b/tools/test/unit/gemm/igemm_32x32x128.cu @@ -0,0 +1,238 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Igemm_32x32x128, igemm_32x32x4_nt) { + typedef cutlass::gemm::IgemmTraits, + int, + cutlass::gemm::LinearScaling, + cutlass::Shape<32, 8, 4> > + IgemmTraits; + run_gemm(32, 32, 4); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Igemm_32x32x128, igemm_32x32x8_nt) { + typedef cutlass::gemm::IgemmTraits, + int, + cutlass::gemm::LinearScaling, + cutlass::Shape<32, 8, 4> > + IgemmTraits; + run_gemm(32, 32, 8); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Igemm_32x32x128, igemm_32x32x32_nt) { + typedef cutlass::gemm::IgemmTraits, + int, + cutlass::gemm::LinearScaling, + cutlass::Shape<32, 8, 4> > + IgemmTraits; + run_gemm(32, 32, 32); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Igemm_32x32x128, igemm_32x32x128_nt) { + typedef cutlass::gemm::IgemmTraits, + int, + cutlass::gemm::LinearScaling, + cutlass::Shape<32, 8, 4> > + IgemmTraits; + run_gemm(32, 32, 128); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Igemm_32x32x128, igemm_32x32x4_nn) { + typedef cutlass::gemm::IgemmTraits, + int, + cutlass::gemm::LinearScaling, + cutlass::Shape<32, 8, 4> > + IgemmTraits; + run_gemm(32, 32, 4); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Igemm_32x32x128, igemm_32x32x8_nn) { + typedef cutlass::gemm::IgemmTraits, + int, + cutlass::gemm::LinearScaling, + cutlass::Shape<32, 8, 4> > + IgemmTraits; + run_gemm(32, 32, 8); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Igemm_32x32x128, igemm_32x32x32_nn) { + typedef cutlass::gemm::IgemmTraits, + int, + cutlass::gemm::LinearScaling, + cutlass::Shape<32, 8, 4> > + IgemmTraits; + run_gemm(32, 32, 32); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Igemm_32x32x128, igemm_32x32x128_nn) { + typedef cutlass::gemm::IgemmTraits, + int, + cutlass::gemm::LinearScaling, + cutlass::Shape<32, 8, 4> > + IgemmTraits; + run_gemm(32, 32, 128); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Igemm_32x32x128, igemm_32x32x4_tn) { + typedef cutlass::gemm::IgemmTraits, + int, + cutlass::gemm::LinearScaling, + cutlass::Shape<32, 8, 4> > + IgemmTraits; + run_gemm(32, 32, 4); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Igemm_32x32x128, igemm_32x32x8_tn) { + typedef cutlass::gemm::IgemmTraits, + int, + cutlass::gemm::LinearScaling, + cutlass::Shape<32, 8, 4> > + IgemmTraits; + run_gemm(32, 32, 8); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Igemm_32x32x128, igemm_32x32x15_tn) { + typedef cutlass::gemm::IgemmTraits, + int, + cutlass::gemm::LinearScaling, + cutlass::Shape<32, 8, 4> > + IgemmTraits; + run_gemm(32, 32, 15, 16, 16, 32); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Igemm_32x32x128, igemm_32x32x32_tn) { + typedef cutlass::gemm::IgemmTraits, + int, + cutlass::gemm::LinearScaling, + cutlass::Shape<32, 8, 4> > + IgemmTraits; + run_gemm(32, 32, 32); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Igemm_32x32x128, igemm_32x32x128_tn) { + typedef cutlass::gemm::IgemmTraits, + int, + cutlass::gemm::LinearScaling, + cutlass::Shape<32, 8, 4> > + IgemmTraits; + run_gemm(32, 32, 128); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Igemm_32x32x128, igemm_32x32x8_tt) { + typedef cutlass::gemm::IgemmTraits, + int, + cutlass::gemm::LinearScaling, + cutlass::Shape<32, 8, 4> > + IgemmTraits; + run_gemm(32, 32, 8); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Igemm_32x32x128, igemm_32x32x32_tt) { + typedef cutlass::gemm::IgemmTraits, + int, + cutlass::gemm::LinearScaling, + cutlass::Shape<32, 8, 4> > + IgemmTraits; + run_gemm(32, 32, 32); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Igemm_32x32x128, igemm_32x32x128_tt) { + typedef cutlass::gemm::IgemmTraits, + int, + cutlass::gemm::LinearScaling, + cutlass::Shape<32, 8, 4> > + IgemmTraits; + run_gemm(32, 32, 128); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/test/unit/gemm/sgemm_128x128x16.cu b/tools/test/unit/gemm/sgemm_128x128x16.cu new file mode 100644 index 000000000..234a2d976 --- /dev/null +++ b/tools/test/unit/gemm/sgemm_128x128x16.cu @@ -0,0 +1,410 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_128x128x16_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_128x81x1_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 81, 1); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_128x112x16_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 112, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_128x112x17_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 112, 17); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_128x73x16_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 73, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_97x112x64_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(97, 112, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_256x112x16_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(256, 112, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_128x240x16_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 240, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_256x240x16_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(256, 240, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_128x128x16_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_128x112x1_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 112, 1); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_79x112x16_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(79, 112, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_128x81x17_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 81, 17); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_128x112x16_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 112, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_128x73x64_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 73, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_256x112x16_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(256, 112, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_128x256x16_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 256, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_256x256x16_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(256, 256, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_128x128x16_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_128x128x1_tn) { + typedef cutlass::gemm::SgemmTraits > SgemmTraits; + run_gemm(128, 128, 1); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_127x112x16_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(127, 112, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_21x112x17_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(21, 112, 17); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_128x73x16_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 73, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_128x81x64_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 81, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_256x112x16_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(256, 112, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_47x256x16_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(47, 256, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_211x256x16_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(211, 256, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_128x128x16_tt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_128x128x1_tt) { + typedef cutlass::gemm::SgemmTraits > SgemmTraits; + run_gemm(128, 128, 1); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_109x112x16_tt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(109, 112, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_128x112x17_tt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 112, 17); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_128x112x16_tt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 112, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_123x112x64_tt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(123, 112, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_256x112x16_tt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(256, 112, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_128x256x16_tt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 256, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_256x256x16_tt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(256, 256, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_120x112x64_ldg4_nt) { + // Load 4 floats per LDG for A/B. + typedef cutlass::gemm::SgemmTraits, + cutlass::gemm::LinearScaling, + cutlass::Shape<8, 8, 8>, + 4, + 4> + SgemmTraits; + run_gemm(120, 112, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_128x128x16_alpha2_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 128, 16, 2.f, 0.f); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_128x112x16_beta1_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 112, 16, 1.f, 1.f); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x128x16, sgemm_128x112x16_alpha2_beta1_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 112, 16, 2.f, 1.f); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/test/unit/gemm/sgemm_128x128x8.cu b/tools/test/unit/gemm/sgemm_128x128x8.cu index 9d615a98e..51f91217b 100644 --- a/tools/test/unit/gemm/sgemm_128x128x8.cu +++ b/tools/test/unit/gemm/sgemm_128x128x8.cu @@ -334,7 +334,7 @@ TEST(Sgemm_128x128x8, sgemm_256x256x16_tt) { TEST(Sgemm_128x128x8, sgemm_120x112x64_ldg4_nt) { // Load 4 floats per LDG for A/B. typedef cutlass::gemm::SgemmTraits, cutlass::gemm::LinearScaling, cutlass::Shape<8, 8, 8>, diff --git a/tools/test/unit/gemm/sgemm_128x32x16.cu b/tools/test/unit/gemm/sgemm_128x32x16.cu new file mode 100644 index 000000000..6b5d80210 --- /dev/null +++ b/tools/test/unit/gemm/sgemm_128x32x16.cu @@ -0,0 +1,294 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x32x16, sgemm_128x32x1_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 32, 1); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x32x16, sgemm_128x32x16_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 32, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x32x16, sgemm_128x32x17_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 32, 17); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x32x16, sgemm_128x32x32_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 32, 32); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x32x16, sgemm_256x32x16_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(256, 32, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x32x16, sgemm_128x64x16_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x32x16, sgemm_256x64x16_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(256, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x32x16, sgemm_128x32x1_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 32, 1); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x32x16, sgemm_128x32x16_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 32, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x32x16, sgemm_128x32x17_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 32, 17); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x32x16, sgemm_128x32x32_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 32, 32); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x32x16, sgemm_256x32x16_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(256, 32, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x32x16, sgemm_128x64x16_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x32x16, sgemm_256x64x16_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(256, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x32x16, sgemm_128x32x1_tn) { + typedef cutlass::gemm::SgemmTraits > SgemmTraits; + run_gemm(128, 128, 1); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x32x16, sgemm_128x32x16_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 32, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x32x16, sgemm_128x32x17_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 32, 17); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x32x16, sgemm_128x32x32_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 32, 32); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x32x16, sgemm_256x32x16_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(256, 32, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x32x16, sgemm_128x64x16_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x32x16, sgemm_256x64x16_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(256, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x32x16, sgemm_128x32x1_tt) { + typedef cutlass::gemm::SgemmTraits > SgemmTraits; + run_gemm(128, 128, 1); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x32x16, sgemm_128x32x16_tt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 32, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x32x16, sgemm_128x32x17_tt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 32, 17); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x32x16, sgemm_128x32x32_tt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 32, 32); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x32x16, sgemm_256x32x16_tt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(256, 32, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x32x16, sgemm_128x64x16_tt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x32x16, sgemm_256x64x16_tt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(256, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/test/unit/gemm/sgemm_128x64x16.cu b/tools/test/unit/gemm/sgemm_128x64x16.cu new file mode 100644 index 000000000..d49f7b19a --- /dev/null +++ b/tools/test/unit/gemm/sgemm_128x64x16.cu @@ -0,0 +1,285 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x64x16, sgemm_128x64x1_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 64, 1); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x64x16, sgemm_128x64x16_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x64x16, sgemm_128x64x17_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 64, 17); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x64x16, sgemm_128x64x64_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 64, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x64x16, sgemm_256x64x16_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(256, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x64x16, sgemm_128x128x16_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x64x16, sgemm_256x128x16_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(256, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x64x16, sgemm_128x64x1_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 64, 1); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x64x16, sgemm_128x64x8_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x64x16, sgemm_128x64x17_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 64, 17); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x64x16, sgemm_128x64x64_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 64, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x64x16, sgemm_256x64x16_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(256, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x64x16, sgemm_128x128x16_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x64x16, sgemm_256x128x16_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(256, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x64x16, sgemm_128x64x1_tn) { + typedef cutlass::gemm::SgemmTraits > SgemmTraits; + run_gemm(128, 128, 1); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x64x16, sgemm_128x64x16_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x64x16, sgemm_128x64x17_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 64, 17); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x64x16, sgemm_128x64x64_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 64, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x64x16, sgemm_256x64x16_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(256, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x64x16, sgemm_128x128x16_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x64x16, sgemm_256x128x16_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(256, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x64x16, sgemm_128x64x1_tt) { + typedef cutlass::gemm::SgemmTraits > SgemmTraits; + run_gemm(128, 128, 1); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x64x16, sgemm_128x64x16_tt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x64x16, sgemm_128x64x17_tt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 64, 17); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x64x16, sgemm_128x64x64_tt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 64, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x64x16, sgemm_128x128x16_tt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_128x64x16, sgemm_256x128x16_tt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(256, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/test/unit/gemm/sgemm_128x64x8.cu b/tools/test/unit/gemm/sgemm_128x64x8.cu index 75c67cceb..fc8185dbb 100644 --- a/tools/test/unit/gemm/sgemm_128x64x8.cu +++ b/tools/test/unit/gemm/sgemm_128x64x8.cu @@ -333,10 +333,10 @@ TEST(Sgemm_128x64x8, sgemm_256x128x16_tt) { TEST(Sgemm_128x64x8, sgemm_128x64x64_8x4_accumulators_nt) { typedef cutlass::gemm::SgemmTraits, cutlass::gemm::LinearScaling, - cutlass::Shape<8, 4, 8> > + cutlass::Shape<8, 8, 8> > SgemmTraits; run_gemm(128, 64, 64); } @@ -345,7 +345,7 @@ TEST(Sgemm_128x64x8, sgemm_128x64x64_8x4_accumulators_nt) { TEST(Sgemm_128x64x8, sgemm_128x64x64_4x8_accumulators_nt) { typedef cutlass::gemm::SgemmTraits, cutlass::gemm::LinearScaling, cutlass::Shape<8, 8, 4> > diff --git a/tools/test/unit/gemm/sgemm_64x128x16.cu b/tools/test/unit/gemm/sgemm_64x128x16.cu new file mode 100644 index 000000000..5fdeb1f6f --- /dev/null +++ b/tools/test/unit/gemm/sgemm_64x128x16.cu @@ -0,0 +1,43 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x128x16, sgemm_64x128x64_4x8_accumulators_nt) { + typedef cutlass::gemm::SgemmTraits, + cutlass::gemm::LinearScaling, + cutlass::Shape<8, 8, 4> > + SgemmTraits; + run_gemm(64, 128, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/test/unit/gemm/sgemm_64x128x8.cu b/tools/test/unit/gemm/sgemm_64x128x8.cu index 6eccefc56..6d3448e0d 100644 --- a/tools/test/unit/gemm/sgemm_64x128x8.cu +++ b/tools/test/unit/gemm/sgemm_64x128x8.cu @@ -32,7 +32,7 @@ TEST(Sgemm_64x128x8, sgemm_64x128x64_4x8_accumulators_nt) { typedef cutlass::gemm::SgemmTraits, cutlass::gemm::LinearScaling, cutlass::Shape<8, 8, 4> > diff --git a/tools/test/unit/gemm/sgemm_64x32x16.cu b/tools/test/unit/gemm/sgemm_64x32x16.cu new file mode 100644 index 000000000..e0f7841a2 --- /dev/null +++ b/tools/test/unit/gemm/sgemm_64x32x16.cu @@ -0,0 +1,277 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x32x16, sgemm_64x32x1_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 32, 1); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x32x16, sgemm_64x32x16_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 32, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x32x16, sgemm_64x32x17_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 32, 17); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x32x16, sgemm_64x32x64_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 32, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x32x16, sgemm_128x32x16_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 32, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x32x16, sgemm_64x64x16_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x32x16, sgemm_128x64x16_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x32x16, sgemm_64x32x1_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 32, 1); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x32x16, sgemm_64x32x16_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 32, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x32x16, sgemm_64x32x17_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 32, 17); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x32x16, sgemm_64x32x64_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 32, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x32x16, sgemm_128x32x16_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 32, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x32x16, sgemm_64x64x16_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x32x16, sgemm_128x64x16_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x32x16, sgemm_64x32x16_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 32, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x32x16, sgemm_64x32x17_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 32, 17); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x32x16, sgemm_64x32x64_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 32, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x32x16, sgemm_128x32x16_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 32, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x32x16, sgemm_64x64x16_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x32x16, sgemm_128x64x16_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + TEST(Sgemm_64x32x16, sgemm_64x64x1_tt) { + typedef cutlass::gemm::SgemmTraits > SgemmTraits; + run_gemm(64, 64, 1); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x32x16, sgemm_64x32x16_tt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 32, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x32x16, sgemm_64x32x17_tt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 32, 17); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x32x16, sgemm_128x32x16_tt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 32, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x32x16, sgemm_64x64x16_tt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x32x16, sgemm_128x64x16_tt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/test/unit/gemm/sgemm_64x64x16.cu b/tools/test/unit/gemm/sgemm_64x64x16.cu new file mode 100644 index 000000000..3dd79e607 --- /dev/null +++ b/tools/test/unit/gemm/sgemm_64x64x16.cu @@ -0,0 +1,294 @@ +/*************************************************************************************************** + * Copyright (c) 2017-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without modification, are permitted + * provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright notice, this list of + * conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright notice, this list of + * conditions and the following disclaimer in the documentation and/or other materials + * provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used + * to endorse or promote products derived from this software without specific prior written + * permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR + * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND + * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, + * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; + * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, + * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x64x16, sgemm_64x64x1_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 64, 1); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x64x16, sgemm_64x64x16_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x64x16, sgemm_64x64x17_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 64, 17); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x64x16, sgemm_64x64x64_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 64, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x64x16, sgemm_128x64x16_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x64x16, sgemm_64x128x16_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x64x16, sgemm_128x128x16_nt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x64x16, sgemm_64x64x1_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 64, 1); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x64x16, sgemm_64x64x16_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x64x16, sgemm_64x64x17_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 64, 17); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x64x16, sgemm_64x64x64_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 64, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x64x16, sgemm_128x64x16_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x64x16, sgemm_64x128x16_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x64x16, sgemm_128x128x16_nn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x64x16, sgemm_64x64x1_tn) { + typedef cutlass::gemm::SgemmTraits > SgemmTraits; + run_gemm(64, 64, 1); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x64x16, sgemm_64x64x16_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x64x16, sgemm_64x64x17_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 64, 17); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x64x16, sgemm_64x64x64_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 64, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x64x16, sgemm_128x64x16_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x64x16, sgemm_64x128x16_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x64x16, sgemm_128x128x16_tn) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x64x16, sgemm_64x64x1_tt) { + typedef cutlass::gemm::SgemmTraits > SgemmTraits; + run_gemm(64, 64, 1); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x64x16, sgemm_64x64x16_tt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x64x16, sgemm_64x64x17_tt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 64, 17); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x64x16, sgemm_64x64x64_tt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 64, 64); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x64x16, sgemm_128x64x16_tt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 64, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x64x16, sgemm_64x128x16_tt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(64, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +TEST(Sgemm_64x64x16, sgemm_128x128x16_tt) { + typedef cutlass::gemm::SgemmTraits > + SgemmTraits; + run_gemm(128, 128, 16); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/test/unit/gemm/wmma_gemm.cu b/tools/test/unit/gemm/wmma_gemm.cu index 0698f4e75..6db07afce 100644 --- a/tools/test/unit/gemm/wmma_gemm.cu +++ b/tools/test/unit/gemm/wmma_gemm.cu @@ -63,6 +63,7 @@ TEST(WmmaGemm_128x128x32, wmma_16x16x16_gemm_256x256x128_nt) { //////////////////////////////////////////////////////////////////////////////////////////////////// +#if defined(CUDA_VERSION) && CUDA_VERSION >= 9100 TEST(WmmaGemm_128x128x32, wmma_8x32x16_gemm_256x256x128_nt) { typedef cutlass::gemm::WmmaGemmTraits(256, 256, 128); } +#endif //////////////////////////////////////////////////////////////////////////////////////////////////// +#if defined(CUDA_VERSION) && CUDA_VERSION >= 9100 TEST(WmmaGemm_128x128x32, wmma_32x8x16_gemm_256x256x128_nt) { typedef cutlass::gemm::WmmaGemmTraits(256, 256, 128); } +#endif //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -124,6 +128,7 @@ TEST(WmmaGemm_128x128x32, wmma_16x16x16_gemm_256x256x128_nn) { //////////////////////////////////////////////////////////////////////////////////////////////////// +#if defined(CUDA_VERSION) && CUDA_VERSION >= 9100 TEST(WmmaGemm_128x128x32, wmma_8x32x16_gemm_256x256x128_nn) { typedef cutlass::gemm::WmmaGemmTraits(256, 256, 128); } +#endif //////////////////////////////////////////////////////////////////////////////////////////////////// +#if defined(CUDA_VERSION) && CUDA_VERSION >= 9100 TEST(WmmaGemm_128x128x32, wmma_32x8x16_gemm_256x256x128_nn) { typedef cutlass::gemm::WmmaGemmTraits(256, 256, 128); } +#endif //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -185,6 +193,7 @@ TEST(WmmaGemm_128x128x32, wmma_16x16x16_gemm_256x256x128_tt) { //////////////////////////////////////////////////////////////////////////////////////////////////// +#if defined(CUDA_VERSION) && CUDA_VERSION >= 9100 TEST(WmmaGemm_128x128x32, wmma_8x32x16_gemm_256x256x128_tt) { typedef cutlass::gemm::WmmaGemmTraits(256, 256, 128); } +#endif //////////////////////////////////////////////////////////////////////////////////////////////////// +#if defined(CUDA_VERSION) && CUDA_VERSION >= 9100 TEST(WmmaGemm_128x128x32, wmma_32x8x16_gemm_256x256x128_tt) { typedef cutlass::gemm::WmmaGemmTraits(256, 256, 128); } +#endif //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -246,6 +258,7 @@ TEST(WmmaGemm_128x128x32, wmma_16x16x16_gemm_256x256x128_tn) { //////////////////////////////////////////////////////////////////////////////////////////////////// +#if defined(CUDA_VERSION) && CUDA_VERSION >= 9100 TEST(WmmaGemm_128x128x32, wmma_8x32x16_gemm_256x256x128_tn) { typedef cutlass::gemm::WmmaGemmTraits(256, 256, 128); } +#endif //////////////////////////////////////////////////////////////////////////////////////////////////// +#if defined(CUDA_VERSION) && CUDA_VERSION >= 9100 TEST(WmmaGemm_128x128x32, wmma_32x8x16_gemm_256x256x128_tn) { typedef cutlass::gemm::WmmaGemmTraits(256, 256, 128); } +#endif //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/tools/util/host_tensor.h b/tools/util/host_tensor.h index 2d0763e62..cc9963c22 100644 --- a/tools/util/host_tensor.h +++ b/tools/util/host_tensor.h @@ -109,6 +109,9 @@ class HostTensor : public HostTensorView { host_.clear(); host_.resize(_capacity); + for (size_t i = 0; i < _capacity; ++i) { + host_[i] = T((int)0xdeadbeef); + } device_.reset(_device_memory, _capacity); Base::reset(TensorRef_t(host_.data(), _stride), _size);